crypto: Initial logic for session unwedging.

master
Damir Jelić 2020-10-09 15:39:35 +02:00
parent 6d2e9cfc02
commit bd0ac703a0
6 changed files with 168 additions and 15 deletions

View File

@ -48,8 +48,8 @@ pub enum OlmError {
Store(#[from] CryptoStoreError), Store(#[from] CryptoStoreError),
/// The session with a device has become corrupted. /// The session with a device has become corrupted.
#[error("decryption failed likely because a Olm session was wedged")] #[error("decryption failed likely because an Olm from {0} with sender key {1} was wedged")]
SessionWedged, SessionWedged(UserId, String),
/// Encryption failed because the device does not have a valid Olm session /// Encryption failed because the device does not have a valid Olm session
/// with us. /// with us.

View File

@ -31,12 +31,13 @@ use matrix_sdk_common::{
EventType, EventType,
}, },
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId}, identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId},
locks::Mutex,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::warn; use tracing::warn;
use crate::olm::InboundGroupSession; use crate::olm::{InboundGroupSession, Session};
#[cfg(test)] #[cfg(test)]
use crate::{OlmMachine, ReadOnlyAccount}; use crate::{OlmMachine, ReadOnlyAccount};
@ -89,6 +90,15 @@ impl Device {
.await .await
} }
/// Get the Olm sessions that belong to this device.
pub(crate) async fn get_sessions(&self) -> StoreResult<Option<Arc<Mutex<Vec<Session>>>>> {
if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) {
self.verification_machine.store.get_sessions(k).await
} else {
Ok(None)
}
}
/// Get the trust state of the device. /// Get the trust state of the device.
pub fn trust_state(&self) -> bool { pub fn trust_state(&self) -> bool {
self.inner self.inner

View File

@ -45,7 +45,7 @@ use matrix_sdk_common::{
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
use crate::store::sqlite::SqliteStore; use crate::store::sqlite::SqliteStore;
use crate::{ use crate::{
error::{EventError, MegolmError, MegolmResult, OlmResult}, error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult},
group_manager::GroupSessionManager, group_manager::GroupSessionManager,
identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities}, identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities},
key_request::KeyRequestMachine, key_request::KeyRequestMachine,
@ -645,6 +645,8 @@ impl OlmMachine {
.mark_outgoing_request_as_sent(request_id) .mark_outgoing_request_as_sent(request_id)
.await?; .await?;
self.group_session_manager.mark_request_as_sent(request_id); self.group_session_manager.mark_request_as_sent(request_id);
self.session_manager
.mark_outgoing_request_as_sent(request_id);
Ok(()) Ok(())
} }
@ -710,8 +712,19 @@ impl OlmMachine {
"Failed to decrypt to-device event from {} {}", "Failed to decrypt to-device event from {} {}",
e.sender, err e.sender, err
); );
// TODO if the session is wedged mark it for
// unwedging. if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self
.session_manager
.mark_device_as_wedged(&sender, &curve_key)
.await
{
error!(
"Couldn't mark device from {} to be unwedged {:?}",
sender, e
);
}
}
continue; continue;
} }
}; };

View File

@ -205,7 +205,10 @@ impl Account {
for sender {} and sender_key {} {:?}", for sender {} and sender_key {} {:?}",
sender, sender_key, e sender, sender_key, e
); );
return Err(OlmError::SessionWedged); return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
} }
} }
} }
@ -248,7 +251,10 @@ impl Account {
available sessions {} {}", available sessions {} {}",
sender, sender_key sender, sender_key
); );
return Err(OlmError::SessionWedged); return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
} }
OlmMessage::PreKey(m) => { OlmMessage::PreKey(m) => {
@ -265,7 +271,10 @@ impl Account {
from a prekey message: {}", from a prekey message: {}",
sender, sender_key, e sender, sender_key, e
); );
return Err(OlmError::SessionWedged); return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
} }
}; };

View File

@ -16,14 +16,26 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration};
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, api::r0::{
keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
to_device::DeviceIdOrAllDevices,
},
assign, assign,
identifiers::{DeviceIdBox, DeviceKeyAlgorithm, UserId}, events::EventType,
identifiers::{DeviceId, DeviceIdBox, DeviceKeyAlgorithm, UserId},
uuid::Uuid, uuid::Uuid,
}; };
use serde_json::{json, value::to_raw_value};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use crate::{error::OlmResult, key_request::KeyRequestMachine, olm::Account, store::Store}; use crate::{
error::OlmResult,
key_request::KeyRequestMachine,
olm::Account,
requests::{OutgoingRequest, ToDeviceRequest},
store::{Result as StoreResult, Store},
Device,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct SessionManager { pub(crate) struct SessionManager {
@ -34,11 +46,14 @@ pub(crate) struct SessionManager {
/// user/device paris will be added to the list of users when /// user/device paris will be added to the list of users when
/// [`get_missing_sessions`](#method.get_missing_sessions) is called. /// [`get_missing_sessions`](#method.get_missing_sessions) is called.
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>, users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
wedged_devices: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
key_request_machine: KeyRequestMachine, key_request_machine: KeyRequestMachine,
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
} }
impl SessionManager { impl SessionManager {
const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10); const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60);
pub fn new( pub fn new(
account: Account, account: Account,
@ -51,9 +66,95 @@ impl SessionManager {
store, store,
key_request_machine, key_request_machine,
users_for_key_claim, users_for_key_claim,
wedged_devices: Arc::new(DashMap::new()),
outgoing_to_device_requests: Arc::new(DashMap::new()),
} }
} }
/// Mark the outgoing request as sent.
pub fn mark_outgoing_request_as_sent(&self, id: &Uuid) {
self.outgoing_to_device_requests.remove(id);
}
pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> {
if let Some(device) = self
.store
.get_device_from_curve_key(sender, curve_key)
.await?
{
let sessions = device.get_sessions().await?;
if let Some(sessions) = sessions {
let mut sessions = sessions.lock().await;
sessions.sort_by_key(|s| s.creation_time.clone());
let session = sessions.get(0);
if let Some(session) = session {
if session.creation_time.elapsed() > Self::UNWEDGING_INTERVAL {
self.wedged_devices
.entry(device.user_id().to_owned())
.or_insert_with(DashSet::new)
.insert(device.device_id().into());
}
}
}
}
Ok(())
}
#[allow(dead_code)]
pub fn is_device_wedged(&self, device: &Device) -> bool {
self.wedged_devices
.get(device.user_id())
.map(|d| d.contains(device.device_id()))
.unwrap_or(false)
}
/// Check if the session was created to unwedge a Device.
///
/// If the device was wedged this will queue up a dummy to-device message.
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
if self
.wedged_devices
.get(user_id)
.map(|d| d.remove(device_id))
.flatten()
.is_some()
{
if let Some(device) = self.store.get_device(user_id, device_id).await? {
let content = device.encrypt(EventType::Dummy, json!({})).await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
messages
.entry(device.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
let request = OutgoingRequest {
request_id: id,
request: Arc::new(
ToDeviceRequest {
event_type: EventType::RoomEncrypted,
txn_id: id,
messages,
}
.into(),
),
};
self.outgoing_to_device_requests.insert(id, request);
}
}
Ok(())
}
/// Get the a key claiming request for the user/device pairs that we are /// Get the a key claiming request for the user/device pairs that we are
/// missing Olm sessions for. /// missing Olm sessions for.
/// ///
@ -189,9 +290,14 @@ impl SessionManager {
continue; continue;
} }
// TODO if this session was created because a previous one was
// wedged queue up a dummy event to be sent out.
self.key_request_machine.retry_keyshare(&user_id, device_id); self.key_request_machine.retry_keyshare(&user_id, device_id);
if let Err(e) = self.check_if_unwedged(&user_id, device_id).await {
error!(
"Error while treating an unwedged device {} {} {:?}",
user_id, device_id, e
);
}
} }
} }
Ok(()) Ok(())

View File

@ -63,7 +63,9 @@ use url::ParseError;
use sqlx::Error as SqlxError; use sqlx::Error as SqlxError;
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{DeviceId, Error as IdentifierValidationError, RoomId, UserId}, identifiers::{
DeviceId, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId, UserId,
},
locks::Mutex, locks::Mutex,
}; };
use matrix_sdk_common_macros::async_trait; use matrix_sdk_common_macros::async_trait;
@ -118,6 +120,19 @@ impl Store {
self.inner.get_user_devices(user_id).await self.inner.get_user_devices(user_id).await
} }
pub async fn get_device_from_curve_key(
&self,
user_id: &UserId,
curve_key: &str,
) -> Result<Option<Device>> {
self.get_user_devices(user_id).await.map(|d| {
d.devices().find(|d| {
d.get_key(DeviceKeyAlgorithm::Curve25519)
.map_or(false, |k| k == curve_key)
})
})
}
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> { pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
let devices = self.inner.get_user_devices(user_id).await?; let devices = self.inner.get_user_devices(user_id).await?;