crypto: Streamline the key claiming so we use the new mark request as sent method.

master
Damir Jelić 2020-08-21 14:40:49 +02:00
parent 93e1967119
commit e38bfc64f4
5 changed files with 68 additions and 78 deletions

View File

@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(feature = "encryption")]
use std::collections::BTreeMap;
use std::{ use std::{
collections::HashMap, collections::HashMap,
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
@ -76,7 +74,6 @@ use matrix_sdk_common::{
Response as ToDeviceResponse, Response as ToDeviceResponse,
}, },
}, },
identifiers::DeviceKeyAlgorithm,
locks::Mutex, locks::Mutex,
}; };
@ -970,10 +967,9 @@ impl Client {
self.base_client.get_missing_sessions(members).await? self.base_client.get_missing_sessions(members).await?
}; };
if !missing_sessions.is_empty() { if let Some((request_id, request)) = missing_sessions {
self.claim_one_time_keys(missing_sessions).await?; self.claim_one_time_keys(&request_id, request).await?;
} }
let response = self.share_group_session(room_id).await; let response = self.share_group_session(room_id).await;
self.group_session_locks.remove(room_id); self.group_session_locks.remove(room_id);
@ -1301,16 +1297,12 @@ impl Client {
#[instrument] #[instrument]
async fn claim_one_time_keys( async fn claim_one_time_keys(
&self, &self,
one_time_keys: BTreeMap<UserId, BTreeMap<Box<DeviceId>, DeviceKeyAlgorithm>>, request_id: &Uuid,
request: claim_keys::Request,
) -> Result<claim_keys::Response> { ) -> Result<claim_keys::Response> {
let request = claim_keys::Request {
timeout: None,
one_time_keys,
};
let response = self.send(request).await?; let response = self.send(request).await?;
self.base_client self.base_client
.receive_keys_claim_response(&response) .mark_request_as_sent(request_id, &response)
.await?; .await?;
Ok(response) Ok(response)
} }

View File

@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(feature = "encryption")]
use std::collections::BTreeMap;
use std::{ use std::{
collections::HashMap, collections::HashMap,
fmt, fmt,
@ -41,12 +39,12 @@ use matrix_sdk_common::{
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::claim_keys::Response as KeysClaimResponse, api::r0::keys::claim_keys::Request as KeysClaimRequest,
api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest, api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest,
events::room::{ events::room::{
encrypted::EncryptedEventContent, message::MessageEventContent as MsgEventContent, encrypted::EncryptedEventContent, message::MessageEventContent as MsgEventContent,
}, },
identifiers::{DeviceId, DeviceKeyAlgorithm}, identifiers::DeviceId,
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_crypto::{ use matrix_sdk_crypto::{
@ -1282,12 +1280,12 @@ impl BaseClient {
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
) -> Result<BTreeMap<UserId, BTreeMap<Box<DeviceId>, DeviceKeyAlgorithm>>> { ) -> Result<Option<(Uuid, KeysClaimRequest)>> {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
match &*olm { match &*olm {
Some(o) => Ok(o.get_missing_sessions(users).await?), Some(o) => Ok(o.get_missing_sessions(users).await?),
None => Ok(BTreeMap::new()), None => Ok(None),
} }
} }
@ -1334,25 +1332,6 @@ impl BaseClient {
} }
} }
/// Receive a successful keys claim response.
///
/// # Arguments
///
/// * `response` - The keys claim response of the request that the client
/// performed.
///
/// # Panics
/// Panics if the client hasn't been logged in.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> Result<()> {
let olm = self.olm.lock().await;
let o = olm.as_ref().expect("Client isn't logged in.");
o.receive_keys_claim_response(response).await?;
Ok(())
}
/// Invalidate the currently active outbound group session for the given /// Invalidate the currently active outbound group session for the given
/// room. /// room.
/// ///

View File

@ -40,7 +40,7 @@ mod verification;
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices}; pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
pub use error::{MegolmError, OlmError}; pub use error::{MegolmError, OlmError};
pub use machine::{OlmMachine, OneTimeKeys}; pub use machine::OlmMachine;
pub use memory_stores::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore}; pub use memory_stores::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore};
pub use olm::{ pub use olm::{
Account, EncryptionSettings, IdentityKeys, InboundGroupSession, OutboundGroupSession, Session, Account, EncryptionSettings, IdentityKeys, InboundGroupSession, OutboundGroupSession, Session,

View File

@ -19,6 +19,7 @@ use std::{
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
mem, mem,
sync::Arc, sync::Arc,
time::Duration,
}; };
use dashmap::DashMap; use dashmap::DashMap;
@ -28,9 +29,9 @@ use tracing::{debug, error, info, instrument, trace, warn};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
keys::{ keys::{
claim_keys, claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
get_keys::{IncomingRequest as KeysQueryRequest, Response as KeysQueryResponse}, get_keys::{IncomingRequest as KeysQueryRequest, Response as KeysQueryResponse},
upload_keys, OneTimeKey, upload_keys,
}, },
sync::sync_events::Response as SyncResponse, sync::sync_events::Response as SyncResponse,
to_device::{ to_device::{
@ -45,9 +46,7 @@ use matrix_sdk_common::{
room_key_request::RoomKeyRequestEventContent, room_key_request::RoomKeyRequestEventContent,
AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent, AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent,
}, },
identifiers::{ identifiers::{DeviceId, DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId, UserId},
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId,
},
uuid::Uuid, uuid::Uuid,
Raw, Raw,
}; };
@ -71,11 +70,6 @@ use super::{
CryptoStore, CryptoStore,
}; };
/// A map from the algorithm and device id to a one-time key.
///
/// These keys need to be periodically uploaded to the server.
pub type OneTimeKeys = BTreeMap<DeviceKeyId, OneTimeKey>;
/// State machine implementation of the Olm/Megolm encryption protocol used for /// State machine implementation of the Olm/Megolm encryption protocol used for
/// Matrix end to end encryption. /// Matrix end to end encryption.
#[derive(Clone)] #[derive(Clone)]
@ -109,6 +103,7 @@ impl std::fmt::Debug for OlmMachine {
impl OlmMachine { impl OlmMachine {
const MAX_TO_DEVICE_MESSAGES: usize = 20; const MAX_TO_DEVICE_MESSAGES: usize = 20;
const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
/// Create a new memory based OlmMachine. /// Create a new memory based OlmMachine.
/// ///
@ -250,11 +245,14 @@ impl OlmMachine {
response: impl Into<IncomingResponse<'a>>, response: impl Into<IncomingResponse<'a>>,
) -> OlmResult<()> { ) -> OlmResult<()> {
match response.into() { match response.into() {
IncomingResponse::KeysUpload(response) => {
self.receive_keys_upload_response(response).await?;
}
IncomingResponse::KeysQuery(response) => { IncomingResponse::KeysQuery(response) => {
self.receive_keys_query_response(response).await?; self.receive_keys_query_response(response).await?;
} }
IncomingResponse::KeysUpload(response) => { IncomingResponse::KeysClaim(response) => {
self.receive_keys_upload_response(response).await?; self.receive_keys_claim_response(response).await?;
} }
IncomingResponse::ToDevice(_) => { IncomingResponse::ToDevice(_) => {
self.mark_to_device_request_as_sent(&request_id.to_string()); self.mark_to_device_request_as_sent(&request_id.to_string());
@ -344,12 +342,10 @@ impl OlmMachine {
Ok(()) Ok(())
} }
/// Get the user/device pairs for which no Olm session exists. /// Get the a key claiming request for the user/device pairs that we are
/// missing Olm sessions for.
/// ///
/// Returns a map from the user id, to a map from the device id to a key /// Returns None if no key claiming request needs to be sent out.
/// algorithm.
///
/// This can be used to make a key claiming request to the server.
/// ///
/// Sessions need to be established between devices so group sessions for a /// Sessions need to be established between devices so group sessions for a
/// room can be shared with them. /// room can be shared with them.
@ -357,18 +353,18 @@ impl OlmMachine {
/// This should be called every time a group session needs to be shared. /// This should be called every time a group session needs to be shared.
/// ///
/// The response of a successful key claiming requests needs to be passed to /// The response of a successful key claiming requests needs to be passed to
/// the `OlmMachine` with the [`receive_keys_claim_response`]. /// the `OlmMachine` with the [`mark_requests_as_sent`].
/// ///
/// # Arguments /// # Arguments
/// ///
/// `users` - The list of users that we should check if we lack a session /// `users` - The list of users that we should check if we lack a session
/// with one of their devices. /// with one of their devices.
/// ///
/// [`receive_keys_claim_response`]: #method.receive_keys_claim_response /// [`mark_requests_as_sent`]: #method.mark_requests_as_sent
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
) -> OlmResult<BTreeMap<UserId, BTreeMap<Box<DeviceId>, DeviceKeyAlgorithm>>> { ) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
let mut missing = BTreeMap::new(); let mut missing = BTreeMap::new();
for user_id in users { for user_id in users {
@ -403,7 +399,17 @@ impl OlmMachine {
} }
} }
Ok(missing) if missing.is_empty() {
Ok(None)
} else {
Ok(Some((
Uuid::new_v4(),
KeysClaimRequest {
timeout: Some(OlmMachine::KEY_CLAIM_TIMEOUT),
one_time_keys: missing,
},
)))
}
} }
/// Receive a successful key claim response and create new Olm sessions with /// Receive a successful key claim response and create new Olm sessions with
@ -412,10 +418,7 @@ impl OlmMachine {
/// # Arguments /// # Arguments
/// ///
/// * `response` - The response containing the claimed one-time keys. /// * `response` - The response containing the claimed one-time keys.
pub async fn receive_keys_claim_response( async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
&self,
response: &claim_keys::Response,
) -> OlmResult<()> {
// TODO log the failures here // TODO log the failures here
for (user_id, user_devices) in &response.one_time_keys { for (user_id, user_devices) in &response.one_time_keys {
@ -1523,7 +1526,6 @@ impl OlmMachine {
pub(crate) mod test { pub(crate) mod test {
static USER_ID: &str = "@bob:example.org"; static USER_ID: &str = "@bob:example.org";
use matrix_sdk_common::js_int::uint;
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
@ -1536,13 +1538,15 @@ pub(crate) mod test {
use tempfile::tempdir; use tempfile::tempdir;
use crate::{ use crate::{
machine::{OlmMachine, OneTimeKeys}, machine::OlmMachine, verification::test::request_to_event, verify_json, EncryptionSettings,
verification::test::request_to_event, ReadOnlyDevice,
verify_json, EncryptionSettings, ReadOnlyDevice,
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{keys, to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest}, api::r0::{
keys::{claim_keys, get_keys, upload_keys, OneTimeKey},
to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest,
},
events::{ events::{
room::{ room::{
encrypted::EncryptedEventContent, encrypted::EncryptedEventContent,
@ -1558,6 +1562,11 @@ pub(crate) mod test {
}; };
use matrix_sdk_test::test_json; use matrix_sdk_test::test_json;
/// These keys need to be periodically uploaded to the server.
type OneTimeKeys = BTreeMap<DeviceKeyId, OneTimeKey>;
use matrix_sdk_common::js_int::uint;
fn alice_id() -> UserId { fn alice_id() -> UserId {
user_id!("@alice:example.org") user_id!("@alice:example.org")
} }
@ -1577,14 +1586,14 @@ pub(crate) mod test {
.unwrap() .unwrap()
} }
fn keys_upload_response() -> keys::upload_keys::Response { fn keys_upload_response() -> upload_keys::Response {
let data = response_from_file(&test_json::KEYS_UPLOAD); let data = response_from_file(&test_json::KEYS_UPLOAD);
keys::upload_keys::Response::try_from(data).expect("Can't parse the keys upload response") upload_keys::Response::try_from(data).expect("Can't parse the keys upload response")
} }
fn keys_query_response() -> keys::get_keys::Response { fn keys_query_response() -> get_keys::Response {
let data = response_from_file(&test_json::KEYS_QUERY); let data = response_from_file(&test_json::KEYS_QUERY);
keys::get_keys::Response::try_from(data).expect("Can't parse the keys upload response") get_keys::Response::try_from(data).expect("Can't parse the keys upload response")
} }
fn to_device_requests_to_content(requests: Vec<OwnedToDeviceRequest>) -> EncryptedEventContent { fn to_device_requests_to_content(requests: Vec<OwnedToDeviceRequest>) -> EncryptedEventContent {
@ -1662,7 +1671,7 @@ pub(crate) mod test {
let mut one_time_keys = BTreeMap::new(); let mut one_time_keys = BTreeMap::new();
one_time_keys.insert(bob.user_id.clone(), bob_keys); one_time_keys.insert(bob.user_id.clone(), bob_keys);
let response = keys::claim_keys::Response { let response = claim_keys::Response {
failures: BTreeMap::new(), failures: BTreeMap::new(),
one_time_keys, one_time_keys,
}; };
@ -1906,13 +1915,14 @@ pub(crate) mod test {
let alice = alice_id(); let alice = alice_id();
let alice_device = alice_device_id(); let alice_device = alice_device_id();
let missing_sessions = machine let (_, missing_sessions) = machine
.get_missing_sessions([alice.clone()].iter()) .get_missing_sessions([alice.clone()].iter())
.await .await
.unwrap()
.unwrap(); .unwrap();
assert!(missing_sessions.contains_key(&alice)); assert!(missing_sessions.one_time_keys.contains_key(&alice));
let user_sessions = missing_sessions.get(&alice).unwrap(); let user_sessions = missing_sessions.one_time_keys.get(&alice).unwrap();
assert!(user_sessions.contains_key(&alice_device)); assert!(user_sessions.contains_key(&alice_device));
} }
@ -1930,7 +1940,7 @@ pub(crate) mod test {
let mut one_time_keys = BTreeMap::new(); let mut one_time_keys = BTreeMap::new();
one_time_keys.insert(bob_machine.user_id.clone(), bob_keys); one_time_keys.insert(bob_machine.user_id.clone(), bob_keys);
let response = keys::claim_keys::Response { let response = claim_keys::Response {
failures: BTreeMap::new(), failures: BTreeMap::new(),
one_time_keys, one_time_keys,
}; };

View File

@ -15,6 +15,7 @@
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
keys::{ keys::{
claim_keys::Response as KeysClaimResponse,
get_keys::{IncomingRequest as KeysQueryRequest, Response as KeysQueryResponse}, get_keys::{IncomingRequest as KeysQueryRequest, Response as KeysQueryResponse},
upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse}, upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse},
}, },
@ -57,6 +58,8 @@ pub enum IncomingResponse<'a> {
KeysQuery(&'a KeysQueryResponse), KeysQuery(&'a KeysQueryResponse),
/// TODO /// TODO
ToDevice(&'a ToDeviceResponse), ToDevice(&'a ToDeviceResponse),
///
KeysClaim(&'a KeysClaimResponse),
} }
impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> { impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> {
@ -77,6 +80,12 @@ impl<'a> From<&'a ToDeviceResponse> for IncomingResponse<'a> {
} }
} }
impl<'a> From<&'a KeysClaimResponse> for IncomingResponse<'a> {
fn from(response: &'a KeysClaimResponse) -> Self {
IncomingResponse::KeysClaim(response)
}
}
/// TODO /// TODO
#[derive(Debug)] #[derive(Debug)]
pub struct OutgoingRequest { pub struct OutgoingRequest {