matrix-sdk: Add automatic key claiming support.
parent
8ea0035cd0
commit
17d23eb9e5
|
@ -105,7 +105,7 @@ use matrix_sdk_common::{
|
|||
#[cfg(feature = "encryption")]
|
||||
use matrix_sdk_common::{
|
||||
api::r0::{
|
||||
keys::{claim_keys, get_keys, upload_keys},
|
||||
keys::{get_keys, upload_keys},
|
||||
to_device::send_event_to_device::{
|
||||
Request as RumaToDeviceRequest, Response as ToDeviceResponse,
|
||||
},
|
||||
|
@ -143,6 +143,9 @@ pub struct Client {
|
|||
/// flight per room.
|
||||
#[cfg(feature = "encryption")]
|
||||
group_session_locks: DashMap<RoomId, Arc<Mutex<()>>>,
|
||||
#[cfg(feature = "encryption")]
|
||||
/// Lock making sure we're only doing one key claim request at a time.
|
||||
key_claim_lock: Arc<Mutex<()>>,
|
||||
}
|
||||
|
||||
#[cfg(not(tarpaulin_include))]
|
||||
|
@ -404,6 +407,8 @@ impl Client {
|
|||
base_client,
|
||||
#[cfg(feature = "encryption")]
|
||||
group_session_locks: DashMap::new(),
|
||||
#[cfg(feature = "encryption")]
|
||||
key_claim_lock: Arc::new(Mutex::new(())),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1017,19 +1022,16 @@ impl Client {
|
|||
|
||||
let _guard = mutex.lock().await;
|
||||
|
||||
let missing_sessions = {
|
||||
{
|
||||
let room = self.base_client.get_joined_room(room_id).await;
|
||||
let room = room.as_ref().unwrap().read().await;
|
||||
let members = room
|
||||
let mut members = room
|
||||
.joined_members
|
||||
.keys()
|
||||
.chain(room.invited_members.keys());
|
||||
self.base_client.get_missing_sessions(members).await?
|
||||
self.claim_one_time_keys(&mut members).await?;
|
||||
};
|
||||
|
||||
if let Some((request_id, request)) = missing_sessions {
|
||||
self.claim_one_time_keys(&request_id, request).await?;
|
||||
}
|
||||
let response = self.share_group_session(room_id).await;
|
||||
|
||||
self.group_session_locks.remove(room_id);
|
||||
|
@ -1520,6 +1522,10 @@ impl Client {
|
|||
|
||||
#[cfg(feature = "encryption")]
|
||||
{
|
||||
if let Err(e) = self.claim_one_time_keys(&mut [].iter()).await {
|
||||
warn!("Error while claiming one-time keys {:?}", e);
|
||||
}
|
||||
|
||||
for r in self.base_client.outgoing_requests().await {
|
||||
match r.request() {
|
||||
OutgoingRequests::KeysQuery(request) => {
|
||||
|
@ -1582,23 +1588,20 @@ impl Client {
|
|||
///
|
||||
/// * `users` - The list of user/device pairs that we should claim keys for.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the client isn't logged in, or if no encryption keys need to
|
||||
/// be uploaded.
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
#[instrument]
|
||||
async fn claim_one_time_keys(
|
||||
&self,
|
||||
request_id: &Uuid,
|
||||
request: claim_keys::Request,
|
||||
) -> Result<claim_keys::Response> {
|
||||
#[instrument(skip(users))]
|
||||
async fn claim_one_time_keys(&self, users: &mut impl Iterator<Item = &UserId>) -> Result<()> {
|
||||
let _lock = self.key_claim_lock.lock().await;
|
||||
|
||||
if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? {
|
||||
let response = self.send(request).await?;
|
||||
self.base_client
|
||||
.mark_request_as_sent(request_id, &response)
|
||||
.mark_request_as_sent(&request_id, &response)
|
||||
.await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Share a group session for a room.
|
||||
|
|
|
@ -1291,7 +1291,7 @@ impl BaseClient {
|
|||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
pub async fn get_missing_sessions(
|
||||
&self,
|
||||
users: impl Iterator<Item = &UserId>,
|
||||
users: &mut impl Iterator<Item = &UserId>,
|
||||
) -> Result<Option<(Uuid, KeysClaimRequest)>> {
|
||||
let olm = self.olm.lock().await;
|
||||
|
||||
|
|
|
@ -379,7 +379,7 @@ impl OlmMachine {
|
|||
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
||||
pub async fn get_missing_sessions(
|
||||
&self,
|
||||
users: impl Iterator<Item = &UserId>,
|
||||
users: &mut impl Iterator<Item = &UserId>,
|
||||
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
|
||||
let mut missing = BTreeMap::new();
|
||||
|
||||
|
@ -1456,7 +1456,7 @@ pub(crate) mod test {
|
|||
let alice_device = alice_device_id();
|
||||
|
||||
let (_, missing_sessions) = machine
|
||||
.get_missing_sessions([alice.clone()].iter())
|
||||
.get_missing_sessions(&mut [alice.clone()].iter())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
Loading…
Reference in New Issue