matrix-sdk: Add automatic key claiming support.
parent
8ea0035cd0
commit
17d23eb9e5
|
@ -105,7 +105,7 @@ use matrix_sdk_common::{
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
api::r0::{
|
api::r0::{
|
||||||
keys::{claim_keys, get_keys, upload_keys},
|
keys::{get_keys, upload_keys},
|
||||||
to_device::send_event_to_device::{
|
to_device::send_event_to_device::{
|
||||||
Request as RumaToDeviceRequest, Response as ToDeviceResponse,
|
Request as RumaToDeviceRequest, Response as ToDeviceResponse,
|
||||||
},
|
},
|
||||||
|
@ -143,6 +143,9 @@ pub struct Client {
|
||||||
/// flight per room.
|
/// flight per room.
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
group_session_locks: DashMap<RoomId, Arc<Mutex<()>>>,
|
group_session_locks: DashMap<RoomId, Arc<Mutex<()>>>,
|
||||||
|
#[cfg(feature = "encryption")]
|
||||||
|
/// Lock making sure we're only doing one key claim request at a time.
|
||||||
|
key_claim_lock: Arc<Mutex<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(tarpaulin_include))]
|
#[cfg(not(tarpaulin_include))]
|
||||||
|
@ -404,6 +407,8 @@ impl Client {
|
||||||
base_client,
|
base_client,
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
group_session_locks: DashMap::new(),
|
group_session_locks: DashMap::new(),
|
||||||
|
#[cfg(feature = "encryption")]
|
||||||
|
key_claim_lock: Arc::new(Mutex::new(())),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1017,19 +1022,16 @@ impl Client {
|
||||||
|
|
||||||
let _guard = mutex.lock().await;
|
let _guard = mutex.lock().await;
|
||||||
|
|
||||||
let missing_sessions = {
|
{
|
||||||
let room = self.base_client.get_joined_room(room_id).await;
|
let room = self.base_client.get_joined_room(room_id).await;
|
||||||
let room = room.as_ref().unwrap().read().await;
|
let room = room.as_ref().unwrap().read().await;
|
||||||
let members = room
|
let mut members = room
|
||||||
.joined_members
|
.joined_members
|
||||||
.keys()
|
.keys()
|
||||||
.chain(room.invited_members.keys());
|
.chain(room.invited_members.keys());
|
||||||
self.base_client.get_missing_sessions(members).await?
|
self.claim_one_time_keys(&mut members).await?;
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some((request_id, request)) = missing_sessions {
|
|
||||||
self.claim_one_time_keys(&request_id, request).await?;
|
|
||||||
}
|
|
||||||
let response = self.share_group_session(room_id).await;
|
let response = self.share_group_session(room_id).await;
|
||||||
|
|
||||||
self.group_session_locks.remove(room_id);
|
self.group_session_locks.remove(room_id);
|
||||||
|
@ -1520,6 +1522,10 @@ impl Client {
|
||||||
|
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
{
|
{
|
||||||
|
if let Err(e) = self.claim_one_time_keys(&mut [].iter()).await {
|
||||||
|
warn!("Error while claiming one-time keys {:?}", e);
|
||||||
|
}
|
||||||
|
|
||||||
for r in self.base_client.outgoing_requests().await {
|
for r in self.base_client.outgoing_requests().await {
|
||||||
match r.request() {
|
match r.request() {
|
||||||
OutgoingRequests::KeysQuery(request) => {
|
OutgoingRequests::KeysQuery(request) => {
|
||||||
|
@ -1582,23 +1588,20 @@ impl Client {
|
||||||
///
|
///
|
||||||
/// * `users` - The list of user/device pairs that we should claim keys for.
|
/// * `users` - The list of user/device pairs that we should claim keys for.
|
||||||
///
|
///
|
||||||
/// # Panics
|
|
||||||
///
|
|
||||||
/// Panics if the client isn't logged in, or if no encryption keys need to
|
|
||||||
/// be uploaded.
|
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||||
#[instrument]
|
#[instrument(skip(users))]
|
||||||
async fn claim_one_time_keys(
|
async fn claim_one_time_keys(&self, users: &mut impl Iterator<Item = &UserId>) -> Result<()> {
|
||||||
&self,
|
let _lock = self.key_claim_lock.lock().await;
|
||||||
request_id: &Uuid,
|
|
||||||
request: claim_keys::Request,
|
if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? {
|
||||||
) -> Result<claim_keys::Response> {
|
let response = self.send(request).await?;
|
||||||
let response = self.send(request).await?;
|
self.base_client
|
||||||
self.base_client
|
.mark_request_as_sent(&request_id, &response)
|
||||||
.mark_request_as_sent(request_id, &response)
|
.await?;
|
||||||
.await?;
|
}
|
||||||
Ok(response)
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Share a group session for a room.
|
/// Share a group session for a room.
|
||||||
|
|
|
@ -1291,7 +1291,7 @@ impl BaseClient {
|
||||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||||
pub async fn get_missing_sessions(
|
pub async fn get_missing_sessions(
|
||||||
&self,
|
&self,
|
||||||
users: impl Iterator<Item = &UserId>,
|
users: &mut impl Iterator<Item = &UserId>,
|
||||||
) -> Result<Option<(Uuid, KeysClaimRequest)>> {
|
) -> Result<Option<(Uuid, KeysClaimRequest)>> {
|
||||||
let olm = self.olm.lock().await;
|
let olm = self.olm.lock().await;
|
||||||
|
|
||||||
|
|
|
@ -379,7 +379,7 @@ impl OlmMachine {
|
||||||
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
||||||
pub async fn get_missing_sessions(
|
pub async fn get_missing_sessions(
|
||||||
&self,
|
&self,
|
||||||
users: impl Iterator<Item = &UserId>,
|
users: &mut impl Iterator<Item = &UserId>,
|
||||||
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
|
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
|
||||||
let mut missing = BTreeMap::new();
|
let mut missing = BTreeMap::new();
|
||||||
|
|
||||||
|
@ -1456,7 +1456,7 @@ pub(crate) mod test {
|
||||||
let alice_device = alice_device_id();
|
let alice_device = alice_device_id();
|
||||||
|
|
||||||
let (_, missing_sessions) = machine
|
let (_, missing_sessions) = machine
|
||||||
.get_missing_sessions([alice.clone()].iter())
|
.get_missing_sessions(&mut [alice.clone()].iter())
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
Loading…
Reference in New Issue