diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 674ae84d..5fe8dbc3 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -38,7 +38,7 @@ use tracing::{error, info, instrument}; use matrix_sdk_base::{BaseClient, BaseClientConfig, Room, Session, StateStore}; #[cfg(feature = "encryption")] -use matrix_sdk_base::CryptoStoreError; +use matrix_sdk_base::{CryptoStoreError, OutgoingRequests}; use matrix_sdk_common::{ api::r0::{ @@ -1235,29 +1235,27 @@ impl Client { #[cfg(feature = "encryption")] { - if self.base_client.should_upload_keys().await { - let response = self.keys_upload().await; + for r in self.base_client.outgoing_requests().await { + match r.request { + OutgoingRequests::KeysQuery(request) => { + if let Err(e) = self.keys_query(&r.request_id, request).await { + warn!("Error while querying device keys {:?}", e); + } + } - if let Err(e) = response { - warn!("Error while uploading E2EE keys {:?}", e); - } - } - - if self.base_client.should_query_keys().await { - let response = self.keys_query().await; - - if let Err(e) = response { - warn!("Error while querying device keys {:?}", e); - } - } - - for request in self.base_client.outgoing_to_device_requests().await { - let txn_id = request.txn_id.clone(); - - if self.send_to_device(request).await.is_ok() { - self.base_client - .mark_to_device_request_as_sent(&txn_id) - .await; + OutgoingRequests::KeysUpload(request) => { + if let Err(e) = self.keys_upload(&r.request_id, request).await { + warn!("Error while querying device keys {:?}", e); + } + } + OutgoingRequests::ToDeviceRequest(request) => { + if let Ok(resp) = self.send_to_device(request).await { + self.base_client + .mark_request_as_sent(&r.request_id, &resp) + .await + .unwrap(); + } + } } } } @@ -1356,13 +1354,11 @@ impl Client { #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[instrument] - async fn keys_upload(&self) -> Result { - let request = self - .base_client - .keys_for_upload() - .await - .expect("Keys don't need to be uploaded"); - + async fn keys_upload( + &self, + request_id: &Uuid, + request: upload_keys::Request, + ) -> Result { debug!( "Uploading encryption keys device keys: {}, one-time-keys: {}", request.device_keys.is_some(), @@ -1371,8 +1367,9 @@ impl Client { let response = self.send(request).await?; self.base_client - .receive_keys_upload_response(&response) + .mark_request_as_sent(request_id, &response) .await?; + Ok(response) } @@ -1390,33 +1387,20 @@ impl Client { #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[instrument] - async fn keys_query(&self) -> Result { - let mut users_for_query = self - .base_client - .users_for_key_query() - .await - .expect("Keys don't need to be uploaded"); - - debug!( - "Querying device keys device for users: {:?}", - users_for_query - ); - - let mut device_keys: BTreeMap>> = BTreeMap::new(); - - for user in users_for_query.drain() { - device_keys.insert(user, Vec::new()); - } - + async fn keys_query( + &self, + request_id: &Uuid, + request: get_keys::IncomingRequest, + ) -> Result { let request = get_keys::Request { timeout: None, - device_keys, + device_keys: request.device_keys, token: None, }; let response = self.send(request).await?; self.base_client - .receive_keys_query_response(&response) + .mark_request_as_sent(request_id, &response) .await?; Ok(response) diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index f6d15d22..57761618 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -14,7 +14,7 @@ // limitations under the License. #[cfg(feature = "encryption")] -use std::collections::{BTreeMap, HashSet}; +use std::collections::BTreeMap; use std::{ collections::HashMap, fmt, @@ -36,15 +36,12 @@ use matrix_sdk_common::{ identifiers::{RoomId, UserId}, locks::RwLock, push::Ruleset, + uuid::Uuid, Raw, }; #[cfg(feature = "encryption")] use matrix_sdk_common::{ - api::r0::keys::{ - claim_keys::Response as KeysClaimResponse, - get_keys::Response as KeysQueryResponse, - upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse}, - }, + api::r0::keys::claim_keys::Response as KeysClaimResponse, api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest, events::room::{ encrypted::EncryptedEventContent, message::MessageEventContent as MsgEventContent, @@ -53,7 +50,8 @@ use matrix_sdk_common::{ }; #[cfg(feature = "encryption")] use matrix_sdk_crypto::{ - CryptoStore, CryptoStoreError, Device, OlmError, OlmMachine, Sas, UserDevices, + CryptoStore, CryptoStoreError, Device, IncomingResponse, OlmError, OlmMachine, OutgoingRequest, + Sas, UserDevices, }; use zeroize::Zeroizing; @@ -1229,18 +1227,6 @@ impl BaseClient { Ok(updated) } - /// Should account or one-time keys be uploaded to the server. - #[cfg(feature = "encryption")] - #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn should_upload_keys(&self) -> bool { - let olm = self.olm.lock().await; - - match &*olm { - Some(o) => o.should_upload_keys().await, - None => false, - } - } - /// Should the client share a group session for the given room. /// /// Returns true if a session needs to be shared before room messages can be @@ -1260,15 +1246,31 @@ impl BaseClient { } } - /// Should users be queried for their device keys. + /// Get the list of outgoing requests that need to be sent out. #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn should_query_keys(&self) -> bool { + pub async fn outgoing_requests(&self) -> Vec { let olm = self.olm.lock().await; match &*olm { - Some(o) => o.should_query_keys().await, - None => false, + Some(o) => o.outgoing_requests().await, + None => vec![], + } + } + + /// Get the list of outgoing requests that need to be sent out. + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + pub async fn mark_request_as_sent<'a>( + &self, + request_id: &Uuid, + response: impl Into>, + ) -> Result<()> { + let olm = self.olm.lock().await; + + match &*olm { + Some(o) => Ok(o.mark_requests_as_sent(request_id, response).await?), + None => Ok(()), } } @@ -1332,53 +1334,6 @@ impl BaseClient { } } - /// Get a tuple of device and one-time keys that need to be uploaded. - /// - /// Returns an empty error if no keys need to be uploaded. - #[cfg(feature = "encryption")] - #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn keys_for_upload(&self) -> Option { - let olm = self.olm.lock().await; - - match &*olm { - Some(o) => o.keys_for_upload().await, - None => None, - } - } - - /// Get the users that we need to query keys for. - /// - /// Returns an empty error if no keys need to be queried. - #[cfg(feature = "encryption")] - #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn users_for_key_query(&self) -> StdResult, ()> { - let olm = self.olm.lock().await; - - match &*olm { - Some(o) => Ok(o.users_for_key_query().await), - None => Err(()), - } - } - - /// Receive a successful keys upload response. - /// - /// # Arguments - /// - /// * `response` - The keys upload response of the request that the client - /// performed. - /// - /// # Panics - /// Panics if the client hasn't been logged in. - #[cfg(feature = "encryption")] - #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) -> Result<()> { - let olm = self.olm.lock().await; - - let o = olm.as_ref().expect("Client isn't logged in."); - o.receive_keys_upload_response(response).await?; - Ok(()) - } - /// Receive a successful keys claim response. /// /// # Arguments @@ -1398,26 +1353,6 @@ impl BaseClient { Ok(()) } - /// Receive a successful keys query response. - /// - /// # Arguments - /// - /// * `response` - The keys query response of the request that the client - /// performed. - /// - /// # Panics - /// Panics if the client hasn't been logged in. - #[cfg(feature = "encryption")] - #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn receive_keys_query_response(&self, response: &KeysQueryResponse) -> Result<()> { - let olm = self.olm.lock().await; - - let o = olm.as_ref().expect("Client isn't logged in."); - o.receive_keys_query_response(response).await?; - // TODO notify our callers of new devices via some callback. - Ok(()) - } - /// Invalidate the currently active outbound group session for the given /// room. /// diff --git a/matrix_sdk_base/src/lib.rs b/matrix_sdk_base/src/lib.rs index 4bf782ce..f96d1728 100644 --- a/matrix_sdk_base/src/lib.rs +++ b/matrix_sdk_base/src/lib.rs @@ -57,7 +57,8 @@ pub use state::{AllRooms, ClientState}; #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub use matrix_sdk_crypto::{ - CryptoStoreError, Device, LocalTrust, ReadOnlyDevice, Sas, UserDevices, + CryptoStoreError, Device, IncomingResponse, LocalTrust, OutgoingRequest, OutgoingRequests, + ReadOnlyDevice, Sas, UserDevices, }; #[cfg(feature = "messages")] diff --git a/matrix_sdk_crypto/src/lib.rs b/matrix_sdk_crypto/src/lib.rs index 21a9b487..5fcc422b 100644 --- a/matrix_sdk_crypto/src/lib.rs +++ b/matrix_sdk_crypto/src/lib.rs @@ -32,6 +32,7 @@ mod error; mod machine; mod memory_stores; mod olm; +mod requests; mod store; #[allow(dead_code)] mod user_identity; @@ -44,6 +45,7 @@ pub use memory_stores::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, Ses pub use olm::{ Account, EncryptionSettings, IdentityKeys, InboundGroupSession, OutboundGroupSession, Session, }; +pub use requests::{IncomingResponse, OutgoingRequest, OutgoingRequests}; #[cfg(feature = "sqlite_cryptostore")] pub use store::sqlite::SqliteStore; pub use store::{CryptoStore, CryptoStoreError}; diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 6c71ba28..21745032 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -27,7 +27,11 @@ use tracing::{debug, error, info, instrument, trace, warn}; use matrix_sdk_common::{ api::r0::{ - keys::{claim_keys, get_keys, upload_keys, OneTimeKey}, + keys::{ + claim_keys, + get_keys::{IncomingRequest as KeysQueryRequest, Response as KeysQueryResponse}, + upload_keys, OneTimeKey, + }, sync::sync_events::Response as SyncResponse, to_device::{ send_event_to_device::IncomingRequest as OwnedToDeviceRequest, DeviceIdOrAllDevices, @@ -57,6 +61,7 @@ use super::{ Account, EncryptionSettings, GroupSessionKey, IdentityKeys, InboundGroupSession, OlmMessage, OutboundGroupSession, }, + requests::{IncomingResponse, OutgoingRequest}, store::{memorystore::MemoryStore, Result as StoreResult}, user_identity::{ MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity, @@ -217,6 +222,48 @@ impl OlmMachine { self.account.identity_keys() } + /// Get the outgoing requests that need to be sent out. + pub async fn outgoing_requests(&self) -> Vec { + let mut requests = Vec::new(); + + if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest { + request_id: Uuid::new_v4(), + request: r.into(), + }) { + requests.push(r); + } + + if let Some(r) = self.users_for_key_query().await.map(|r| OutgoingRequest { + request_id: Uuid::new_v4(), + request: r.into(), + }) { + requests.push(r); + } + + requests + } + + /// Mark the request with the given request id as sent. + pub async fn mark_requests_as_sent<'a>( + &self, + request_id: &Uuid, + response: impl Into>, + ) -> OlmResult<()> { + match response.into() { + IncomingResponse::KeysQuery(response) => { + self.receive_keys_query_response(response).await?; + } + IncomingResponse::KeysUpload(response) => { + self.receive_keys_upload_response(response).await?; + } + IncomingResponse::ToDevice(_) => { + self.mark_to_device_request_as_sent(&request_id.to_string()); + } + }; + + Ok(()) + } + /// Should device or one-time keys be uploaded to the server. /// /// This needs to be checked periodically, ideally after every sync request. @@ -241,7 +288,8 @@ impl OlmMachine { /// } /// # }); /// ``` - pub async fn should_upload_keys(&self) -> bool { + #[cfg(test)] + async fn should_upload_keys(&self) -> bool { self.account.should_upload_keys().await } @@ -269,7 +317,7 @@ impl OlmMachine { /// * `response` - The keys upload response of the request that the client /// performed. #[instrument] - pub async fn receive_keys_upload_response( + async fn receive_keys_upload_response( &self, response: &upload_keys::Response, ) -> OlmResult<()> { @@ -507,7 +555,7 @@ impl OlmMachine { /// they are new, one of their properties has changed or they got deleted. async fn handle_cross_singing_keys( &self, - response: &get_keys::Response, + response: &KeysQueryResponse, ) -> StoreResult> { let mut changed = Vec::new(); @@ -613,9 +661,9 @@ impl OlmMachine { /// /// * `response` - The keys query response of the request that the client /// performed. - pub async fn receive_keys_query_response( + async fn receive_keys_query_response( &self, - response: &get_keys::Response, + response: &KeysQueryResponse, ) -> OlmResult<(Vec, Vec)> { let changed_devices = self .handle_devices_from_key_query(&response.device_keys) @@ -636,7 +684,7 @@ impl OlmMachine { /// /// [`receive_keys_upload_response`]: #method.receive_keys_upload_response /// [`OlmMachine`]: struct.OlmMachine.html - pub async fn keys_for_upload(&self) -> Option { + async fn keys_for_upload(&self) -> Option { let (device_keys, one_time_keys) = self.account.keys_for_upload().await?; Some(upload_keys::Request { @@ -1344,22 +1392,34 @@ impl OlmMachine { } } - /// Should the client perform a key query request. - pub async fn should_query_keys(&self) -> bool { - self.store.has_users_for_key_query() - } - - /// Get the set of users that we need to query keys for. + /// Get a key query request if one is needed. /// - /// Returns a hash set of users that need to be queried for keys. + /// Returns a key query reqeust if the client should query E2E keys, + /// otherwise None. /// /// The response of a successful key query requests needs to be passed to /// the [`OlmMachine`] with the [`receive_keys_query_response`]. /// /// [`OlmMachine`]: struct.OlmMachine.html /// [`receive_keys_query_response`]: #method.receive_keys_query_response - pub async fn users_for_key_query(&self) -> HashSet { - self.store.users_for_key_query() + async fn users_for_key_query(&self) -> Option { + let mut users = self.store.users_for_key_query(); + + if users.is_empty() { + None + } else { + let mut device_keys: BTreeMap>> = BTreeMap::new(); + + for user in users.drain() { + device_keys.insert(user, Vec::new()); + } + + Some(KeysQueryRequest { + timeout: None, + device_keys, + token: None, + }) + } } /// Get a specific device of a user. diff --git a/matrix_sdk_crypto/src/requests.rs b/matrix_sdk_crypto/src/requests.rs new file mode 100644 index 00000000..18b6e12d --- /dev/null +++ b/matrix_sdk_crypto/src/requests.rs @@ -0,0 +1,88 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use matrix_sdk_common::{ + api::r0::{ + keys::{ + get_keys::{IncomingRequest as KeysQueryRequest, Response as KeysQueryResponse}, + upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse}, + }, + to_device::send_event_to_device::{ + IncomingRequest as ToDeviceRequest, Response as ToDeviceResponse, + }, + }, + uuid::Uuid, +}; + +/// TODO +#[derive(Debug)] +pub enum OutgoingRequests { + /// TODO + KeysUpload(KeysUploadRequest), + /// TODO + KeysQuery(KeysQueryRequest), + /// TODO + ToDeviceRequest(ToDeviceRequest), +} + +impl From for OutgoingRequests { + fn from(request: KeysQueryRequest) -> Self { + OutgoingRequests::KeysQuery(request) + } +} + +impl From for OutgoingRequests { + fn from(request: KeysUploadRequest) -> Self { + OutgoingRequests::KeysUpload(request) + } +} + +/// TODO +#[derive(Debug)] +pub enum IncomingResponse<'a> { + /// TODO + KeysUpload(&'a KeysUploadResponse), + /// TODO + KeysQuery(&'a KeysQueryResponse), + /// TODO + ToDevice(&'a ToDeviceResponse), +} + +impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> { + fn from(response: &'a KeysUploadResponse) -> Self { + IncomingResponse::KeysUpload(response) + } +} + +impl<'a> From<&'a KeysQueryResponse> for IncomingResponse<'a> { + fn from(response: &'a KeysQueryResponse) -> Self { + IncomingResponse::KeysQuery(response) + } +} + +impl<'a> From<&'a ToDeviceResponse> for IncomingResponse<'a> { + fn from(response: &'a ToDeviceResponse) -> Self { + IncomingResponse::ToDevice(response) + } +} + +/// TODO +#[derive(Debug)] +pub struct OutgoingRequest { + /// The unique id of a request, needs to be passed when receiving a + /// response. + pub request_id: Uuid, + /// TODO + pub request: OutgoingRequests, +}