crypto: Initial logic for session unwedging.
parent
6d2e9cfc02
commit
bd0ac703a0
|
@ -48,8 +48,8 @@ pub enum OlmError {
|
||||||
Store(#[from] CryptoStoreError),
|
Store(#[from] CryptoStoreError),
|
||||||
|
|
||||||
/// The session with a device has become corrupted.
|
/// The session with a device has become corrupted.
|
||||||
#[error("decryption failed likely because a Olm session was wedged")]
|
#[error("decryption failed likely because an Olm from {0} with sender key {1} was wedged")]
|
||||||
SessionWedged,
|
SessionWedged(UserId, String),
|
||||||
|
|
||||||
/// Encryption failed because the device does not have a valid Olm session
|
/// Encryption failed because the device does not have a valid Olm session
|
||||||
/// with us.
|
/// with us.
|
||||||
|
|
|
@ -31,12 +31,13 @@ use matrix_sdk_common::{
|
||||||
EventType,
|
EventType,
|
||||||
},
|
},
|
||||||
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId},
|
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId},
|
||||||
|
locks::Mutex,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
use crate::olm::InboundGroupSession;
|
use crate::olm::{InboundGroupSession, Session};
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use crate::{OlmMachine, ReadOnlyAccount};
|
use crate::{OlmMachine, ReadOnlyAccount};
|
||||||
|
|
||||||
|
@ -89,6 +90,15 @@ impl Device {
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the Olm sessions that belong to this device.
|
||||||
|
pub(crate) async fn get_sessions(&self) -> StoreResult<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||||
|
if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) {
|
||||||
|
self.verification_machine.store.get_sessions(k).await
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the trust state of the device.
|
/// Get the trust state of the device.
|
||||||
pub fn trust_state(&self) -> bool {
|
pub fn trust_state(&self) -> bool {
|
||||||
self.inner
|
self.inner
|
||||||
|
|
|
@ -45,7 +45,7 @@ use matrix_sdk_common::{
|
||||||
#[cfg(feature = "sqlite_cryptostore")]
|
#[cfg(feature = "sqlite_cryptostore")]
|
||||||
use crate::store::sqlite::SqliteStore;
|
use crate::store::sqlite::SqliteStore;
|
||||||
use crate::{
|
use crate::{
|
||||||
error::{EventError, MegolmError, MegolmResult, OlmResult},
|
error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult},
|
||||||
group_manager::GroupSessionManager,
|
group_manager::GroupSessionManager,
|
||||||
identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities},
|
identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities},
|
||||||
key_request::KeyRequestMachine,
|
key_request::KeyRequestMachine,
|
||||||
|
@ -645,6 +645,8 @@ impl OlmMachine {
|
||||||
.mark_outgoing_request_as_sent(request_id)
|
.mark_outgoing_request_as_sent(request_id)
|
||||||
.await?;
|
.await?;
|
||||||
self.group_session_manager.mark_request_as_sent(request_id);
|
self.group_session_manager.mark_request_as_sent(request_id);
|
||||||
|
self.session_manager
|
||||||
|
.mark_outgoing_request_as_sent(request_id);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -710,8 +712,19 @@ impl OlmMachine {
|
||||||
"Failed to decrypt to-device event from {} {}",
|
"Failed to decrypt to-device event from {} {}",
|
||||||
e.sender, err
|
e.sender, err
|
||||||
);
|
);
|
||||||
// TODO if the session is wedged mark it for
|
|
||||||
// unwedging.
|
if let OlmError::SessionWedged(sender, curve_key) = err {
|
||||||
|
if let Err(e) = self
|
||||||
|
.session_manager
|
||||||
|
.mark_device_as_wedged(&sender, &curve_key)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
error!(
|
||||||
|
"Couldn't mark device from {} to be unwedged {:?}",
|
||||||
|
sender, e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -205,7 +205,10 @@ impl Account {
|
||||||
for sender {} and sender_key {} {:?}",
|
for sender {} and sender_key {} {:?}",
|
||||||
sender, sender_key, e
|
sender, sender_key, e
|
||||||
);
|
);
|
||||||
return Err(OlmError::SessionWedged);
|
return Err(OlmError::SessionWedged(
|
||||||
|
sender.to_owned(),
|
||||||
|
sender_key.to_owned(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -248,7 +251,10 @@ impl Account {
|
||||||
available sessions {} {}",
|
available sessions {} {}",
|
||||||
sender, sender_key
|
sender, sender_key
|
||||||
);
|
);
|
||||||
return Err(OlmError::SessionWedged);
|
return Err(OlmError::SessionWedged(
|
||||||
|
sender.to_owned(),
|
||||||
|
sender_key.to_owned(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
OlmMessage::PreKey(m) => {
|
OlmMessage::PreKey(m) => {
|
||||||
|
@ -265,7 +271,10 @@ impl Account {
|
||||||
from a prekey message: {}",
|
from a prekey message: {}",
|
||||||
sender, sender_key, e
|
sender, sender_key, e
|
||||||
);
|
);
|
||||||
return Err(OlmError::SessionWedged);
|
return Err(OlmError::SessionWedged(
|
||||||
|
sender.to_owned(),
|
||||||
|
sender_key.to_owned(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -16,14 +16,26 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration};
|
||||||
|
|
||||||
use dashmap::{DashMap, DashSet};
|
use dashmap::{DashMap, DashSet};
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
api::r0::keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
|
api::r0::{
|
||||||
|
keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
|
||||||
|
to_device::DeviceIdOrAllDevices,
|
||||||
|
},
|
||||||
assign,
|
assign,
|
||||||
identifiers::{DeviceIdBox, DeviceKeyAlgorithm, UserId},
|
events::EventType,
|
||||||
|
identifiers::{DeviceId, DeviceIdBox, DeviceKeyAlgorithm, UserId},
|
||||||
uuid::Uuid,
|
uuid::Uuid,
|
||||||
};
|
};
|
||||||
|
use serde_json::{json, value::to_raw_value};
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
use crate::{error::OlmResult, key_request::KeyRequestMachine, olm::Account, store::Store};
|
use crate::{
|
||||||
|
error::OlmResult,
|
||||||
|
key_request::KeyRequestMachine,
|
||||||
|
olm::Account,
|
||||||
|
requests::{OutgoingRequest, ToDeviceRequest},
|
||||||
|
store::{Result as StoreResult, Store},
|
||||||
|
Device,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct SessionManager {
|
pub(crate) struct SessionManager {
|
||||||
|
@ -34,11 +46,14 @@ pub(crate) struct SessionManager {
|
||||||
/// user/device paris will be added to the list of users when
|
/// user/device paris will be added to the list of users when
|
||||||
/// [`get_missing_sessions`](#method.get_missing_sessions) is called.
|
/// [`get_missing_sessions`](#method.get_missing_sessions) is called.
|
||||||
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
||||||
|
wedged_devices: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
||||||
key_request_machine: KeyRequestMachine,
|
key_request_machine: KeyRequestMachine,
|
||||||
|
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionManager {
|
impl SessionManager {
|
||||||
const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
|
const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
|
||||||
|
const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60);
|
||||||
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
account: Account,
|
account: Account,
|
||||||
|
@ -51,9 +66,95 @@ impl SessionManager {
|
||||||
store,
|
store,
|
||||||
key_request_machine,
|
key_request_machine,
|
||||||
users_for_key_claim,
|
users_for_key_claim,
|
||||||
|
wedged_devices: Arc::new(DashMap::new()),
|
||||||
|
outgoing_to_device_requests: Arc::new(DashMap::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Mark the outgoing request as sent.
|
||||||
|
pub fn mark_outgoing_request_as_sent(&self, id: &Uuid) {
|
||||||
|
self.outgoing_to_device_requests.remove(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> {
|
||||||
|
if let Some(device) = self
|
||||||
|
.store
|
||||||
|
.get_device_from_curve_key(sender, curve_key)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
|
let sessions = device.get_sessions().await?;
|
||||||
|
|
||||||
|
if let Some(sessions) = sessions {
|
||||||
|
let mut sessions = sessions.lock().await;
|
||||||
|
sessions.sort_by_key(|s| s.creation_time.clone());
|
||||||
|
|
||||||
|
let session = sessions.get(0);
|
||||||
|
|
||||||
|
if let Some(session) = session {
|
||||||
|
if session.creation_time.elapsed() > Self::UNWEDGING_INTERVAL {
|
||||||
|
self.wedged_devices
|
||||||
|
.entry(device.user_id().to_owned())
|
||||||
|
.or_insert_with(DashSet::new)
|
||||||
|
.insert(device.device_id().into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn is_device_wedged(&self, device: &Device) -> bool {
|
||||||
|
self.wedged_devices
|
||||||
|
.get(device.user_id())
|
||||||
|
.map(|d| d.contains(device.device_id()))
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the session was created to unwedge a Device.
|
||||||
|
///
|
||||||
|
/// If the device was wedged this will queue up a dummy to-device message.
|
||||||
|
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
|
||||||
|
if self
|
||||||
|
.wedged_devices
|
||||||
|
.get(user_id)
|
||||||
|
.map(|d| d.remove(device_id))
|
||||||
|
.flatten()
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
if let Some(device) = self.store.get_device(user_id, device_id).await? {
|
||||||
|
let content = device.encrypt(EventType::Dummy, json!({})).await?;
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
let mut messages = BTreeMap::new();
|
||||||
|
|
||||||
|
messages
|
||||||
|
.entry(device.user_id().to_owned())
|
||||||
|
.or_insert_with(BTreeMap::new)
|
||||||
|
.insert(
|
||||||
|
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
||||||
|
to_raw_value(&content)?,
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = OutgoingRequest {
|
||||||
|
request_id: id,
|
||||||
|
request: Arc::new(
|
||||||
|
ToDeviceRequest {
|
||||||
|
event_type: EventType::RoomEncrypted,
|
||||||
|
txn_id: id,
|
||||||
|
messages,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.outgoing_to_device_requests.insert(id, request);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the a key claiming request for the user/device pairs that we are
|
/// Get the a key claiming request for the user/device pairs that we are
|
||||||
/// missing Olm sessions for.
|
/// missing Olm sessions for.
|
||||||
///
|
///
|
||||||
|
@ -189,9 +290,14 @@ impl SessionManager {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO if this session was created because a previous one was
|
|
||||||
// wedged queue up a dummy event to be sent out.
|
|
||||||
self.key_request_machine.retry_keyshare(&user_id, device_id);
|
self.key_request_machine.retry_keyshare(&user_id, device_id);
|
||||||
|
|
||||||
|
if let Err(e) = self.check_if_unwedged(&user_id, device_id).await {
|
||||||
|
error!(
|
||||||
|
"Error while treating an unwedged device {} {} {:?}",
|
||||||
|
user_id, device_id, e
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -63,7 +63,9 @@ use url::ParseError;
|
||||||
use sqlx::Error as SqlxError;
|
use sqlx::Error as SqlxError;
|
||||||
|
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
identifiers::{DeviceId, Error as IdentifierValidationError, RoomId, UserId},
|
identifiers::{
|
||||||
|
DeviceId, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId, UserId,
|
||||||
|
},
|
||||||
locks::Mutex,
|
locks::Mutex,
|
||||||
};
|
};
|
||||||
use matrix_sdk_common_macros::async_trait;
|
use matrix_sdk_common_macros::async_trait;
|
||||||
|
@ -118,6 +120,19 @@ impl Store {
|
||||||
self.inner.get_user_devices(user_id).await
|
self.inner.get_user_devices(user_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn get_device_from_curve_key(
|
||||||
|
&self,
|
||||||
|
user_id: &UserId,
|
||||||
|
curve_key: &str,
|
||||||
|
) -> Result<Option<Device>> {
|
||||||
|
self.get_user_devices(user_id).await.map(|d| {
|
||||||
|
d.devices().find(|d| {
|
||||||
|
d.get_key(DeviceKeyAlgorithm::Curve25519)
|
||||||
|
.map_or(false, |k| k == curve_key)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
|
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
|
||||||
let devices = self.inner.get_user_devices(user_id).await?;
|
let devices = self.inner.get_user_devices(user_id).await?;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue