diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index a13fddc2..d44ab146 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -1109,11 +1109,11 @@ impl Client { } #[cfg(feature = "encryption")] - async fn send_to_device(&self, request: OwnedToDeviceRequest) -> Result { + async fn send_to_device(&self, request: &OwnedToDeviceRequest) -> Result { let request = ToDeviceRequest { - event_type: request.event_type, + event_type: request.event_type.clone(), txn_id: &request.txn_id, - messages: request.messages, + messages: request.messages.clone(), }; self.send(request).await @@ -1232,22 +1232,22 @@ impl Client { #[cfg(feature = "encryption")] { for r in self.base_client.outgoing_requests().await { - match r.request { + match r.request() { OutgoingRequests::KeysQuery(request) => { - if let Err(e) = self.keys_query(&r.request_id, request).await { + if let Err(e) = self.keys_query(r.request_id(), request).await { warn!("Error while querying device keys {:?}", e); } } OutgoingRequests::KeysUpload(request) => { - if let Err(e) = self.keys_upload(&r.request_id, request).await { + if let Err(e) = self.keys_upload(&r.request_id(), request).await { warn!("Error while querying device keys {:?}", e); } } OutgoingRequests::ToDeviceRequest(request) => { if let Ok(resp) = self.send_to_device(request).await { self.base_client - .mark_request_as_sent(&r.request_id, &resp) + .mark_request_as_sent(&r.request_id(), &resp) .await .unwrap(); } @@ -1328,7 +1328,7 @@ impl Client { .expect("Keys don't need to be uploaded"); for request in requests.drain(..) { - self.send_to_device(request).await?; + self.send_to_device(&request).await?; } Ok(()) @@ -1349,7 +1349,7 @@ impl Client { async fn keys_upload( &self, request_id: &Uuid, - request: upload_keys::Request, + request: &upload_keys::Request, ) -> Result { debug!( "Uploading encryption keys device keys: {}, one-time-keys: {}", @@ -1357,7 +1357,7 @@ impl Client { request.one_time_keys.as_ref().map_or(0, |k| k.len()) ); - let response = self.send(request).await?; + let response = self.send(request.clone()).await?; self.base_client .mark_request_as_sent(request_id, &response) .await?; @@ -1382,11 +1382,11 @@ impl Client { async fn keys_query( &self, request_id: &Uuid, - request: get_keys::IncomingRequest, + request: &get_keys::IncomingRequest, ) -> Result { let request = get_keys::Request { timeout: None, - device_keys: request.device_keys, + device_keys: request.device_keys.clone(), token: None, }; diff --git a/matrix_sdk/src/sas.rs b/matrix_sdk/src/sas.rs index 3092c5aa..cadaddd1 100644 --- a/matrix_sdk/src/sas.rs +++ b/matrix_sdk/src/sas.rs @@ -56,7 +56,7 @@ impl Sas { /// Cancel the interactive verification flow. pub async fn cancel(&self) -> Result<()> { - if let Some(req) = self.inner.cancel() { + if let Some((_, req)) = self.inner.cancel() { let request = ToDeviceRequest { event_type: req.event_type, txn_id: &req.txn_id, diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index d5554002..d057bd78 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1748,27 +1748,6 @@ impl BaseClient { } } - /// Get the to-device requests that need to be sent out. - #[cfg(feature = "encryption")] - #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn outgoing_to_device_requests(&self) -> Vec { - self.olm - .lock() - .await - .as_ref() - .map(|o| o.outgoing_to_device_requests()) - .unwrap_or_default() - } - - /// Mark an outgoing to-device requests as sent. - #[cfg(feature = "encryption")] - #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn mark_to_device_request_as_sent(&self, request_id: &str) { - if let Some(olm) = self.olm.lock().await.as_ref() { - olm.mark_to_device_request_as_sent(request_id) - } - } - /// Get a `Sas` verification object with the given flow id. /// /// # Arguments diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 79957a5b..b9755a03 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -223,18 +223,20 @@ impl OlmMachine { if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest { request_id: Uuid::new_v4(), - request: r.into(), + request: Arc::new(r.into()), }) { requests.push(r); } if let Some(r) = self.users_for_key_query().await.map(|r| OutgoingRequest { request_id: Uuid::new_v4(), - request: r.into(), + request: Arc::new(r.into()), }) { requests.push(r); } + requests.append(&mut self.outgoing_to_device_requests()); + requests } @@ -255,7 +257,7 @@ impl OlmMachine { self.receive_keys_claim_response(response).await?; } IncomingResponse::ToDevice(_) => { - self.mark_to_device_request_as_sent(&request_id.to_string()); + self.mark_to_device_request_as_sent(&request_id); } }; @@ -1234,12 +1236,12 @@ impl OlmMachine { } /// Get the to-device requests that need to be sent out. - pub fn outgoing_to_device_requests(&self) -> Vec { + fn outgoing_to_device_requests(&self) -> Vec { self.verification_machine.outgoing_to_device_requests() } /// Mark an outgoing to-device requests as sent. - pub fn mark_to_device_request_as_sent(&self, request_id: &str) { + fn mark_to_device_request_as_sent(&self, request_id: &Uuid) { self.verification_machine.mark_requests_as_sent(request_id); } @@ -1538,8 +1540,9 @@ pub(crate) mod test { use tempfile::tempdir; use crate::{ - machine::OlmMachine, verification::test::request_to_event, verify_json, EncryptionSettings, - ReadOnlyDevice, + machine::OlmMachine, + verification::test::{outgoing_request_to_event, request_to_event}, + verify_json, EncryptionSettings, ReadOnlyDevice, }; use matrix_sdk_common::{ @@ -2169,7 +2172,7 @@ pub(crate) mod test { .outgoing_to_device_requests() .iter() .next() - .map(|r| request_to_event(alice.user_id(), &r)) + .map(|r| outgoing_request_to_event(alice.user_id(), r)) .unwrap(); bob.handle_verification_event(&mut event).await; @@ -2177,7 +2180,7 @@ pub(crate) mod test { .outgoing_to_device_requests() .iter() .next() - .map(|r| request_to_event(bob.user_id(), &r)) + .map(|r| outgoing_request_to_event(bob.user_id(), r)) .unwrap(); alice.handle_verification_event(&mut event).await; diff --git a/matrix_sdk_crypto/src/olm/group_sessions.rs b/matrix_sdk_crypto/src/olm/group_sessions.rs index 40949632..a677380c 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions.rs @@ -476,7 +476,10 @@ impl std::fmt::Debug for OutboundGroupSession { #[cfg(test)] mod test { - use std::{thread::sleep, time::Duration}; + use std::{ + sync::Arc, + time::{Duration, Instant}, + }; use matrix_sdk_common::{ events::room::message::{MessageEventContent, TextMessageEventContent}, @@ -487,6 +490,7 @@ mod test { use crate::Account; #[tokio::test] + #[cfg(not(target_os = "macos"))] async fn expiration() { let settings = EncryptionSettings { rotation_period_msgs: 1, @@ -512,13 +516,13 @@ mod test { ..Default::default() }; - let (session, _) = account + let (mut session, _) = account .create_group_session_pair(&room_id!("!test_room:example.org"), settings) .await .unwrap(); assert!(!session.expired()); - sleep(Duration::from_millis(110)); + session.creation_time = Arc::new(Instant::now() - Duration::from_secs(60 * 60)); assert!(session.expired()); } } diff --git a/matrix_sdk_crypto/src/requests.rs b/matrix_sdk_crypto/src/requests.rs index 07816601..0b102076 100644 --- a/matrix_sdk_crypto/src/requests.rs +++ b/matrix_sdk_crypto/src/requests.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use matrix_sdk_common::{ api::r0::{ keys::{ @@ -49,6 +51,12 @@ impl From for OutgoingRequests { } } +impl From for OutgoingRequests { + fn from(request: ToDeviceRequest) -> Self { + OutgoingRequests::ToDeviceRequest(request) + } +} + /// TODO #[derive(Debug)] pub enum IncomingResponse<'a> { @@ -87,11 +95,23 @@ impl<'a> From<&'a KeysClaimResponse> for IncomingResponse<'a> { } /// TODO -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct OutgoingRequest { /// The unique id of a request, needs to be passed when receiving a /// response. - pub request_id: Uuid, + pub(crate) request_id: Uuid, /// TODO - pub request: OutgoingRequests, + pub(crate) request: Arc, +} + +impl OutgoingRequest { + /// Get the unique id of this request. + pub fn request_id(&self) -> &Uuid { + &self.request_id + } + + /// Get the underlying outgoing request. + pub fn request(&self) -> &OutgoingRequests { + &self.request + } } diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index a2699127..3590dced 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -22,17 +22,18 @@ use matrix_sdk_common::{ api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest, events::{AnyToDeviceEvent, AnyToDeviceEventContent}, identifiers::{DeviceId, UserId}, + uuid::Uuid, }; use super::sas::{content_to_request, Sas}; -use crate::{Account, CryptoStore, CryptoStoreError, ReadOnlyDevice}; +use crate::{requests::OutgoingRequest, Account, CryptoStore, CryptoStoreError, ReadOnlyDevice}; #[derive(Clone, Debug)] pub struct VerificationMachine { account: Account, pub(crate) store: Arc>, verifications: Arc>, - outgoing_to_device_messages: Arc>, + outgoing_to_device_messages: Arc>, } impl VerificationMachine { @@ -58,7 +59,7 @@ impl VerificationMachine { identity, ); - let request = content_to_request( + let (_, request) = content_to_request( device.user_id(), device.device_id(), AnyToDeviceEventContent::KeyVerificationStart(content), @@ -81,10 +82,14 @@ impl VerificationMachine { recipient_device: &DeviceId, content: AnyToDeviceEventContent, ) { - let request = content_to_request(recipient, recipient_device, content); + let (request_id, request) = content_to_request(recipient, recipient_device, content); - self.outgoing_to_device_messages - .insert(request.txn_id.clone(), request); + let request = OutgoingRequest { + request_id: request_id.clone(), + request: Arc::new(request.into()), + }; + + self.outgoing_to_device_messages.insert(request_id, request); } fn receive_event_helper(&self, sas: &Sas, event: &mut AnyToDeviceEvent) { @@ -93,19 +98,15 @@ impl VerificationMachine { } } - pub fn mark_requests_as_sent(&self, uuid: &str) { + pub fn mark_requests_as_sent(&self, uuid: &Uuid) { self.outgoing_to_device_messages.remove(uuid); } - pub fn outgoing_to_device_requests(&self) -> Vec { + pub fn outgoing_to_device_requests(&self) -> Vec { #[allow(clippy::map_clone)] self.outgoing_to_device_messages .iter() - .map(|r| OwnedToDeviceRequest { - event_type: r.event_type.clone(), - txn_id: r.txn_id.clone(), - messages: r.messages.clone(), - }) + .map(|r| (*r).clone()) .collect() } @@ -115,7 +116,13 @@ impl VerificationMachine { for sas in self.verifications.iter() { if let Some(r) = sas.cancel_if_timed_out() { - self.outgoing_to_device_messages.insert(r.txn_id.clone(), r); + self.outgoing_to_device_messages.insert( + r.0.clone(), + OutgoingRequest { + request_id: r.0, + request: Arc::new(r.1.into()), + }, + ); } } } @@ -184,7 +191,13 @@ impl VerificationMachine { if s.is_done() && !s.mark_device_as_verified().await? { if let Some(r) = s.cancel() { - self.outgoing_to_device_messages.insert(r.txn_id.clone(), r); + self.outgoing_to_device_messages.insert( + r.0.clone(), + OutgoingRequest { + request_id: r.0, + request: Arc::new(r.1.into()), + }, + ); } } }; @@ -211,6 +224,7 @@ mod test { use super::{Sas, VerificationMachine}; use crate::{ + requests::OutgoingRequests, store::memorystore::MemoryStore, verification::test::{get_content_from_request, wrap_any_to_device_content}, Account, CryptoStore, ReadOnlyDevice, @@ -293,10 +307,15 @@ mod test { .next() .unwrap(); - let txn_id = request.txn_id.clone(); + let txn_id = request.request_id().clone(); - let mut event = - wrap_any_to_device_content(alice.user_id(), get_content_from_request(&request)); + let r = if let OutgoingRequests::ToDeviceRequest(r) = request.request() { + r + } else { + panic!("Invalid request type"); + }; + + let mut event = wrap_any_to_device_content(alice.user_id(), get_content_from_request(r)); drop(request); alice_machine.mark_requests_as_sent(&txn_id); diff --git a/matrix_sdk_crypto/src/verification/mod.rs b/matrix_sdk_crypto/src/verification/mod.rs index 9e503f4e..ab79c0f7 100644 --- a/matrix_sdk_crypto/src/verification/mod.rs +++ b/matrix_sdk_crypto/src/verification/mod.rs @@ -20,6 +20,7 @@ pub use sas::Sas; #[cfg(test)] pub(crate) mod test { + use crate::requests::{OutgoingRequest, OutgoingRequests}; use serde_json::Value; use matrix_sdk_common::{ @@ -36,6 +37,16 @@ pub(crate) mod test { wrap_any_to_device_content(sender, content) } + pub(crate) fn outgoing_request_to_event( + sender: &UserId, + request: &OutgoingRequest, + ) -> AnyToDeviceEvent { + match request.request() { + OutgoingRequests::ToDeviceRequest(r) => request_to_event(sender, r), + _ => panic!("Unsupported outgoing request"), + } + } + pub(crate) fn wrap_any_to_device_content( sender: &UserId, content: AnyToDeviceEventContent, diff --git a/matrix_sdk_crypto/src/verification/sas/helpers.rs b/matrix_sdk_crypto/src/verification/sas/helpers.rs index 1c66fcd4..259248e1 100644 --- a/matrix_sdk_crypto/src/verification/sas/helpers.rs +++ b/matrix_sdk_crypto/src/verification/sas/helpers.rs @@ -464,7 +464,7 @@ pub fn content_to_request( recipient: &UserId, recipient_device: &DeviceId, content: AnyToDeviceEventContent, -) -> OwnedToDeviceRequest { +) -> (Uuid, OwnedToDeviceRequest) { let mut messages = BTreeMap::new(); let mut user_messages = BTreeMap::new(); @@ -483,11 +483,16 @@ pub fn content_to_request( _ => unreachable!(), }; - OwnedToDeviceRequest { - txn_id: Uuid::new_v4().to_string(), - event_type, - messages, - } + let request_id = Uuid::new_v4(); + + ( + request_id, + OwnedToDeviceRequest { + txn_id: request_id.to_string(), + event_type, + messages, + }, + ) } #[cfg(test)] diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 9bf9d153..3e4db5e3 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -31,6 +31,7 @@ use matrix_sdk_common::{ AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent, }, identifiers::{DeviceId, UserId}, + uuid::Uuid, }; use crate::{ @@ -168,7 +169,7 @@ impl Sas { pub fn accept(&self) -> Option { self.inner.lock().unwrap().accept().map(|c| { let content = AnyToDeviceEventContent::KeyVerificationAccept(c); - self.content_to_request(content) + self.content_to_request(content).1 }) } @@ -194,7 +195,7 @@ impl Sas { // else branch and only after the identity was verified as well. We // dont' want to verify one without the other. if !self.mark_device_as_verified().await? { - return Ok(self.cancel()); + return Ok(self.cancel().map(|r| r.1)); } else { self.mark_identity_as_verified().await?; } @@ -202,7 +203,7 @@ impl Sas { Ok(content.map(|c| { let content = AnyToDeviceEventContent::KeyVerificationMac(c); - self.content_to_request(content) + self.content_to_request(content).1 })) } @@ -327,7 +328,7 @@ impl Sas { /// /// Returns None if the `Sas` object is already in a canceled state, /// otherwise it returns a request that needs to be sent out. - pub fn cancel(&self) -> Option { + pub fn cancel(&self) -> Option<(Uuid, OwnedToDeviceRequest)> { let mut guard = self.inner.lock().unwrap(); let sas: InnerSas = (*guard).clone(); let (sas, content) = sas.cancel(CancelCode::User); @@ -336,7 +337,7 @@ impl Sas { content.map(|c| self.content_to_request(c)) } - pub(crate) fn cancel_if_timed_out(&self) -> Option { + pub(crate) fn cancel_if_timed_out(&self) -> Option<(Uuid, OwnedToDeviceRequest)> { if self.is_canceled() || self.is_done() { None } else if self.timed_out() { @@ -410,7 +411,7 @@ impl Sas { pub(crate) fn content_to_request( &self, content: AnyToDeviceEventContent, - ) -> OwnedToDeviceRequest { + ) -> (Uuid, OwnedToDeviceRequest) { content_to_request(self.other_user_id(), self.other_device_id(), content) } }