matrix-sdk: Add automatic key claiming support.

master
Damir Jelić 2020-10-07 14:07:47 +02:00
parent 8ea0035cd0
commit 17d23eb9e5
3 changed files with 28 additions and 25 deletions

View File

@ -105,7 +105,7 @@ use matrix_sdk_common::{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
keys::{claim_keys, get_keys, upload_keys}, keys::{get_keys, upload_keys},
to_device::send_event_to_device::{ to_device::send_event_to_device::{
Request as RumaToDeviceRequest, Response as ToDeviceResponse, Request as RumaToDeviceRequest, Response as ToDeviceResponse,
}, },
@ -143,6 +143,9 @@ pub struct Client {
/// flight per room. /// flight per room.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
group_session_locks: DashMap<RoomId, Arc<Mutex<()>>>, group_session_locks: DashMap<RoomId, Arc<Mutex<()>>>,
#[cfg(feature = "encryption")]
/// Lock making sure we're only doing one key claim request at a time.
key_claim_lock: Arc<Mutex<()>>,
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
@ -404,6 +407,8 @@ impl Client {
base_client, base_client,
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
group_session_locks: DashMap::new(), group_session_locks: DashMap::new(),
#[cfg(feature = "encryption")]
key_claim_lock: Arc::new(Mutex::new(())),
}) })
} }
@ -1017,19 +1022,16 @@ impl Client {
let _guard = mutex.lock().await; let _guard = mutex.lock().await;
let missing_sessions = { {
let room = self.base_client.get_joined_room(room_id).await; let room = self.base_client.get_joined_room(room_id).await;
let room = room.as_ref().unwrap().read().await; let room = room.as_ref().unwrap().read().await;
let members = room let mut members = room
.joined_members .joined_members
.keys() .keys()
.chain(room.invited_members.keys()); .chain(room.invited_members.keys());
self.base_client.get_missing_sessions(members).await? self.claim_one_time_keys(&mut members).await?;
}; };
if let Some((request_id, request)) = missing_sessions {
self.claim_one_time_keys(&request_id, request).await?;
}
let response = self.share_group_session(room_id).await; let response = self.share_group_session(room_id).await;
self.group_session_locks.remove(room_id); self.group_session_locks.remove(room_id);
@ -1520,6 +1522,10 @@ impl Client {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
{ {
if let Err(e) = self.claim_one_time_keys(&mut [].iter()).await {
warn!("Error while claiming one-time keys {:?}", e);
}
for r in self.base_client.outgoing_requests().await { for r in self.base_client.outgoing_requests().await {
match r.request() { match r.request() {
OutgoingRequests::KeysQuery(request) => { OutgoingRequests::KeysQuery(request) => {
@ -1582,23 +1588,20 @@ impl Client {
/// ///
/// * `users` - The list of user/device pairs that we should claim keys for. /// * `users` - The list of user/device pairs that we should claim keys for.
/// ///
/// # Panics
///
/// Panics if the client isn't logged in, or if no encryption keys need to
/// be uploaded.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[instrument] #[instrument(skip(users))]
async fn claim_one_time_keys( async fn claim_one_time_keys(&self, users: &mut impl Iterator<Item = &UserId>) -> Result<()> {
&self, let _lock = self.key_claim_lock.lock().await;
request_id: &Uuid,
request: claim_keys::Request, if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? {
) -> Result<claim_keys::Response> { let response = self.send(request).await?;
let response = self.send(request).await?; self.base_client
self.base_client .mark_request_as_sent(&request_id, &response)
.mark_request_as_sent(request_id, &response) .await?;
.await?; }
Ok(response)
Ok(())
} }
/// Share a group session for a room. /// Share a group session for a room.

View File

@ -1291,7 +1291,7 @@ impl BaseClient {
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: impl Iterator<Item = &UserId>, users: &mut impl Iterator<Item = &UserId>,
) -> Result<Option<(Uuid, KeysClaimRequest)>> { ) -> Result<Option<(Uuid, KeysClaimRequest)>> {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;

View File

@ -379,7 +379,7 @@ impl OlmMachine {
/// [`mark_request_as_sent`]: #method.mark_request_as_sent /// [`mark_request_as_sent`]: #method.mark_request_as_sent
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: impl Iterator<Item = &UserId>, users: &mut impl Iterator<Item = &UserId>,
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> { ) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
let mut missing = BTreeMap::new(); let mut missing = BTreeMap::new();
@ -1456,7 +1456,7 @@ pub(crate) mod test {
let alice_device = alice_device_id(); let alice_device = alice_device_id();
let (_, missing_sessions) = machine let (_, missing_sessions) = machine
.get_missing_sessions([alice.clone()].iter()) .get_missing_sessions(&mut [alice.clone()].iter())
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();