diff --git a/matrix_sdk_crypto/src/error.rs b/matrix_sdk_crypto/src/error.rs index ee4e5f1f..39a554bb 100644 --- a/matrix_sdk_crypto/src/error.rs +++ b/matrix_sdk_crypto/src/error.rs @@ -48,8 +48,8 @@ pub enum OlmError { Store(#[from] CryptoStoreError), /// The session with a device has become corrupted. - #[error("decryption failed likely because a Olm session was wedged")] - SessionWedged, + #[error("decryption failed likely because an Olm from {0} with sender key {1} was wedged")] + SessionWedged(UserId, String), /// Encryption failed because the device does not have a valid Olm session /// with us. diff --git a/matrix_sdk_crypto/src/identities/device.rs b/matrix_sdk_crypto/src/identities/device.rs index 73aeb22d..578e7718 100644 --- a/matrix_sdk_crypto/src/identities/device.rs +++ b/matrix_sdk_crypto/src/identities/device.rs @@ -31,12 +31,13 @@ use matrix_sdk_common::{ EventType, }, identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId}, + locks::Mutex, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tracing::warn; -use crate::olm::InboundGroupSession; +use crate::olm::{InboundGroupSession, Session}; #[cfg(test)] use crate::{OlmMachine, ReadOnlyAccount}; @@ -89,6 +90,15 @@ impl Device { .await } + /// Get the Olm sessions that belong to this device. + pub(crate) async fn get_sessions(&self) -> StoreResult>>>> { + if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) { + self.verification_machine.store.get_sessions(k).await + } else { + Ok(None) + } + } + /// Get the trust state of the device. pub fn trust_state(&self) -> bool { self.inner diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 10a1fd4c..3e64e48a 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -45,7 +45,7 @@ use matrix_sdk_common::{ #[cfg(feature = "sqlite_cryptostore")] use crate::store::sqlite::SqliteStore; use crate::{ - error::{EventError, MegolmError, MegolmResult, OlmResult}, + error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}, group_manager::GroupSessionManager, identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities}, key_request::KeyRequestMachine, @@ -645,6 +645,8 @@ impl OlmMachine { .mark_outgoing_request_as_sent(request_id) .await?; self.group_session_manager.mark_request_as_sent(request_id); + self.session_manager + .mark_outgoing_request_as_sent(request_id); Ok(()) } @@ -710,8 +712,19 @@ impl OlmMachine { "Failed to decrypt to-device event from {} {}", e.sender, err ); - // TODO if the session is wedged mark it for - // unwedging. + + if let OlmError::SessionWedged(sender, curve_key) = err { + if let Err(e) = self + .session_manager + .mark_device_as_wedged(&sender, &curve_key) + .await + { + error!( + "Couldn't mark device from {} to be unwedged {:?}", + sender, e + ); + } + } continue; } }; diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index 0cf5e569..a6522ba1 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -205,7 +205,10 @@ impl Account { for sender {} and sender_key {} {:?}", sender, sender_key, e ); - return Err(OlmError::SessionWedged); + return Err(OlmError::SessionWedged( + sender.to_owned(), + sender_key.to_owned(), + )); } } } @@ -248,7 +251,10 @@ impl Account { available sessions {} {}", sender, sender_key ); - return Err(OlmError::SessionWedged); + return Err(OlmError::SessionWedged( + sender.to_owned(), + sender_key.to_owned(), + )); } OlmMessage::PreKey(m) => { @@ -265,7 +271,10 @@ impl Account { from a prekey message: {}", sender, sender_key, e ); - return Err(OlmError::SessionWedged); + return Err(OlmError::SessionWedged( + sender.to_owned(), + sender_key.to_owned(), + )); } }; diff --git a/matrix_sdk_crypto/src/session_manager.rs b/matrix_sdk_crypto/src/session_manager.rs index e10a2e5e..792caa81 100644 --- a/matrix_sdk_crypto/src/session_manager.rs +++ b/matrix_sdk_crypto/src/session_manager.rs @@ -16,14 +16,26 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration}; use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{ - api::r0::keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, + api::r0::{ + keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, + to_device::DeviceIdOrAllDevices, + }, assign, - identifiers::{DeviceIdBox, DeviceKeyAlgorithm, UserId}, + events::EventType, + identifiers::{DeviceId, DeviceIdBox, DeviceKeyAlgorithm, UserId}, uuid::Uuid, }; +use serde_json::{json, value::to_raw_value}; use tracing::{error, info, warn}; -use crate::{error::OlmResult, key_request::KeyRequestMachine, olm::Account, store::Store}; +use crate::{ + error::OlmResult, + key_request::KeyRequestMachine, + olm::Account, + requests::{OutgoingRequest, ToDeviceRequest}, + store::{Result as StoreResult, Store}, + Device, +}; #[derive(Debug, Clone)] pub(crate) struct SessionManager { @@ -34,11 +46,14 @@ pub(crate) struct SessionManager { /// user/device paris will be added to the list of users when /// [`get_missing_sessions`](#method.get_missing_sessions) is called. users_for_key_claim: Arc>>, + wedged_devices: Arc>>, key_request_machine: KeyRequestMachine, + outgoing_to_device_requests: Arc>, } impl SessionManager { const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10); + const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60); pub fn new( account: Account, @@ -51,9 +66,95 @@ impl SessionManager { store, key_request_machine, users_for_key_claim, + wedged_devices: Arc::new(DashMap::new()), + outgoing_to_device_requests: Arc::new(DashMap::new()), } } + /// Mark the outgoing request as sent. + pub fn mark_outgoing_request_as_sent(&self, id: &Uuid) { + self.outgoing_to_device_requests.remove(id); + } + + pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> { + if let Some(device) = self + .store + .get_device_from_curve_key(sender, curve_key) + .await? + { + let sessions = device.get_sessions().await?; + + if let Some(sessions) = sessions { + let mut sessions = sessions.lock().await; + sessions.sort_by_key(|s| s.creation_time.clone()); + + let session = sessions.get(0); + + if let Some(session) = session { + if session.creation_time.elapsed() > Self::UNWEDGING_INTERVAL { + self.wedged_devices + .entry(device.user_id().to_owned()) + .or_insert_with(DashSet::new) + .insert(device.device_id().into()); + } + } + } + } + + Ok(()) + } + + #[allow(dead_code)] + pub fn is_device_wedged(&self, device: &Device) -> bool { + self.wedged_devices + .get(device.user_id()) + .map(|d| d.contains(device.device_id())) + .unwrap_or(false) + } + + /// Check if the session was created to unwedge a Device. + /// + /// If the device was wedged this will queue up a dummy to-device message. + async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> { + if self + .wedged_devices + .get(user_id) + .map(|d| d.remove(device_id)) + .flatten() + .is_some() + { + if let Some(device) = self.store.get_device(user_id, device_id).await? { + let content = device.encrypt(EventType::Dummy, json!({})).await?; + let id = Uuid::new_v4(); + let mut messages = BTreeMap::new(); + + messages + .entry(device.user_id().to_owned()) + .or_insert_with(BTreeMap::new) + .insert( + DeviceIdOrAllDevices::DeviceId(device.device_id().into()), + to_raw_value(&content)?, + ); + + let request = OutgoingRequest { + request_id: id, + request: Arc::new( + ToDeviceRequest { + event_type: EventType::RoomEncrypted, + txn_id: id, + messages, + } + .into(), + ), + }; + + self.outgoing_to_device_requests.insert(id, request); + } + } + + Ok(()) + } + /// Get the a key claiming request for the user/device pairs that we are /// missing Olm sessions for. /// @@ -189,9 +290,14 @@ impl SessionManager { continue; } - // TODO if this session was created because a previous one was - // wedged queue up a dummy event to be sent out. self.key_request_machine.retry_keyshare(&user_id, device_id); + + if let Err(e) = self.check_if_unwedged(&user_id, device_id).await { + error!( + "Error while treating an unwedged device {} {} {:?}", + user_id, device_id, e + ); + } } } Ok(()) diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 83c2a0f9..767bbae4 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -63,7 +63,9 @@ use url::ParseError; use sqlx::Error as SqlxError; use matrix_sdk_common::{ - identifiers::{DeviceId, Error as IdentifierValidationError, RoomId, UserId}, + identifiers::{ + DeviceId, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId, UserId, + }, locks::Mutex, }; use matrix_sdk_common_macros::async_trait; @@ -118,6 +120,19 @@ impl Store { self.inner.get_user_devices(user_id).await } + pub async fn get_device_from_curve_key( + &self, + user_id: &UserId, + curve_key: &str, + ) -> Result> { + self.get_user_devices(user_id).await.map(|d| { + d.devices().find(|d| { + d.get_key(DeviceKeyAlgorithm::Curve25519) + .map_or(false, |k| k == curve_key) + }) + }) + } + pub async fn get_user_devices(&self, user_id: &UserId) -> Result { let devices = self.inner.get_user_devices(user_id).await?;