crypto: Rework the cryptostore.

This modifies the cryptostore and storage logic in two ways:
    * The cryptostore trait has only one main save method.
    * The receive_sync method tries to save all the objects in one
    `save_changes()` call.

This means that all the changes a sync makes get commited to the store
in one transaction, leaving us in a consistent state.

This also means that we can pass the Changes struct the receive sync
method collects to our caller if the caller wishes to store the room
state and crypto state changes in a single transaction.
This commit is contained in:
Damir Jelić 2020-10-20 17:19:37 +02:00
parent 728d80ed06
commit 7cab7cadc9
13 changed files with 711 additions and 477 deletions

View file

@ -971,8 +971,6 @@ impl BaseClient {
return Ok(());
}
*self.sync_token.write().await = Some(response.next_batch.clone());
#[cfg(feature = "encryption")]
{
let olm = self.olm.lock().await;
@ -982,10 +980,12 @@ impl BaseClient {
// decryptes to-device events, but leaves room events alone.
// This makes sure that we have the deryption keys for the room
// events at hand.
o.receive_sync_response(response).await;
o.receive_sync_response(response).await?;
}
}
*self.sync_token.write().await = Some(response.next_batch.clone());
// when events change state, updated_* signals to StateStore to update database
self.iter_joined_rooms(response).await?;
self.iter_invited_rooms(response).await?;

View file

@ -39,7 +39,10 @@ use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tracing::warn;
use crate::olm::{InboundGroupSession, Session};
use crate::{
olm::{InboundGroupSession, Session},
store::{Changes, DeviceChanges},
};
#[cfg(test)]
use crate::{OlmMachine, ReadOnlyAccount};
@ -118,10 +121,15 @@ impl Device {
pub async fn set_local_trust(&self, trust_state: LocalTrust) -> StoreResult<()> {
self.inner.set_trust_state(trust_state);
self.verification_machine
.store
.save_devices(&[self.inner.clone()])
.await
let changes = Changes {
devices: DeviceChanges {
changed: vec![self.inner.clone()],
..Default::default()
},
..Default::default()
};
self.verification_machine.store.save_changes(changes).await
}
/// Encrypt the given content for this `Device`.
@ -135,7 +143,7 @@ impl Device {
&self,
event_type: EventType,
content: Value,
) -> OlmResult<EncryptedEventContent> {
) -> OlmResult<(Session, EncryptedEventContent)> {
self.inner
.encrypt(&**self.verification_machine.store, event_type, content)
.await
@ -146,7 +154,7 @@ impl Device {
pub async fn encrypt_session(
&self,
session: InboundGroupSession,
) -> OlmResult<EncryptedEventContent> {
) -> OlmResult<(Session, EncryptedEventContent)> {
let export = session.export().await;
let content: ForwardedRoomKeyEventContent = if let Ok(c) = export.try_into() {
@ -364,7 +372,7 @@ impl ReadOnlyDevice {
store: &dyn CryptoStore,
event_type: EventType,
content: Value,
) -> OlmResult<EncryptedEventContent> {
) -> OlmResult<(Session, EncryptedEventContent)> {
let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) {
k
} else {
@ -396,10 +404,9 @@ impl ReadOnlyDevice {
return Err(OlmError::MissingSession);
};
let message = session.encrypt(&self, event_type, content).await;
store.save_sessions(&[session]).await?;
let message = session.encrypt(&self, event_type, content).await?;
message
Ok((session, message))
}
/// Update a device with a new device keys struct.

View file

@ -33,7 +33,7 @@ use crate::{
},
requests::KeysQueryRequest,
session_manager::GroupSessionManager,
store::{Result as StoreResult, Store},
store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
};
#[derive(Debug, Clone)]
@ -79,7 +79,7 @@ impl IdentityManager {
pub async fn receive_keys_query_response(
&self,
response: &KeysQueryResponse,
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> {
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
// TODO create a enum that tells us how the device/identity changed,
// e.g. new/deleted/display name change.
//
@ -92,9 +92,15 @@ impl IdentityManager {
let changed_devices = self
.handle_devices_from_key_query(&response.device_keys)
.await?;
self.store.save_devices(&changed_devices).await?;
let changed_identities = self.handle_cross_singing_keys(response).await?;
self.store.save_user_identities(&changed_identities).await?;
let changes = Changes {
identities: changed_identities.clone(),
devices: changed_devices.clone(),
..Default::default()
};
self.store.save_changes(changes).await?;
Ok((changed_devices, changed_identities))
}
@ -111,9 +117,10 @@ impl IdentityManager {
async fn handle_devices_from_key_query(
&self,
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
) -> StoreResult<Vec<ReadOnlyDevice>> {
) -> StoreResult<DeviceChanges> {
let mut users_with_new_or_deleted_devices = HashSet::new();
let mut changed_devices = Vec::new();
let mut changes = DeviceChanges::default();
for (user_id, device_map) in device_keys_map {
// TODO move this out into the handle keys query response method
@ -137,7 +144,7 @@ impl IdentityManager {
let device = self.store.get_readonly_device(&user_id, device_id).await?;
let device = if let Some(mut device) = device {
if let Some(mut device) = device {
if let Err(e) = device.update_device(device_keys) {
warn!(
"Failed to update the device keys for {} {}: {:?}",
@ -145,7 +152,7 @@ impl IdentityManager {
);
continue;
}
device
changes.changed.push(device);
} else {
let device = match ReadOnlyDevice::try_from(device_keys) {
Ok(d) => d,
@ -159,23 +166,21 @@ impl IdentityManager {
};
info!("Adding a new device to the device store {:?}", device);
users_with_new_or_deleted_devices.insert(user_id);
device
};
changed_devices.push(device);
changes.new.push(device);
}
}
let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
let stored_devices = self.store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices = stored_devices_set.difference(&current_devices);
let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices {
for device_id in deleted_devices_set {
users_with_new_or_deleted_devices.insert(user_id);
if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
self.store.delete_device(device.clone()).await?;
changes.deleted.push(device.clone());
}
}
}
@ -183,7 +188,7 @@ impl IdentityManager {
self.group_manager
.invalidate_sessions_new_devices(&users_with_new_or_deleted_devices);
Ok(changed_devices)
Ok(changes)
}
/// Handle the device keys part of a key query response.
@ -197,8 +202,8 @@ impl IdentityManager {
async fn handle_cross_singing_keys(
&self,
response: &KeysQueryResponse,
) -> StoreResult<Vec<UserIdentities>> {
let mut changed = Vec::new();
) -> StoreResult<IdentityChanges> {
let mut changes = IdentityChanges::default();
for (user_id, master_key) in &response.master_keys {
let master_key = MasterPubkey::from(master_key);
@ -213,7 +218,7 @@ impl IdentityManager {
continue;
};
let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? {
let result = if let Some(mut i) = self.store.get_user_identity(user_id).await? {
match &mut i {
UserIdentities::Own(ref mut identity) => {
let user_signing = if let Some(s) = response.user_signing_keys.get(user_id)
@ -230,11 +235,11 @@ impl IdentityManager {
identity
.update(master_key, self_signing, user_signing)
.map(|_| i)
}
UserIdentities::Other(ref mut identity) => {
identity.update(master_key, self_signing).map(|_| i)
.map(|_| (i, false))
}
UserIdentities::Other(ref mut identity) => identity
.update(master_key, self_signing)
.map(|_| (i, false)),
}
} else if user_id == self.user_id() {
if let Some(s) = response.user_signing_keys.get(user_id) {
@ -252,7 +257,7 @@ impl IdentityManager {
}
OwnUserIdentity::new(master_key, self_signing, user_signing)
.map(UserIdentities::Own)
.map(|i| (UserIdentities::Own(i), true))
} else {
warn!(
"User identity for our own user {} didn't contain a \
@ -268,17 +273,22 @@ impl IdentityManager {
);
continue;
} else {
UserIdentity::new(master_key, self_signing).map(UserIdentities::Other)
UserIdentity::new(master_key, self_signing)
.map(|i| (UserIdentities::Other(i), true))
};
match identity {
Ok(i) => {
match result {
Ok((i, new)) => {
trace!(
"Updated or created new user identity for {}: {:?}",
user_id,
i
);
changed.push(i);
if new {
changes.new.push(i);
} else {
changes.changed.push(i);
}
}
Err(e) => {
warn!(
@ -290,7 +300,7 @@ impl IdentityManager {
}
}
Ok(changed)
Ok(changes)
}
/// Get a key query request if one is needed.

View file

@ -41,7 +41,7 @@ use matrix_sdk_common::{
use crate::{
error::{OlmError, OlmResult},
olm::{InboundGroupSession, OutboundGroupSession},
olm::{InboundGroupSession, OutboundGroupSession, Session},
requests::{OutgoingRequest, ToDeviceRequest},
store::{CryptoStoreError, Store},
Device,
@ -235,15 +235,18 @@ impl KeyRequestMachine {
/// Handle all the incoming key requests that are queued up and empty our
/// key request queue.
pub async fn collect_incoming_key_requests(&self) -> OlmResult<()> {
pub async fn collect_incoming_key_requests(&self) -> OlmResult<Vec<Session>> {
let mut changed_sessions = Vec::new();
for item in self.incoming_key_requests.iter() {
let event = item.value();
self.handle_key_request(event).await?;
if let Some(s) = self.handle_key_request(event).await? {
changed_sessions.push(s);
}
}
self.incoming_key_requests.clear();
Ok(())
Ok(changed_sessions)
}
/// Store the key share request for later, once we get an Olm session with
@ -294,7 +297,7 @@ impl KeyRequestMachine {
async fn handle_key_request(
&self,
event: &ToDeviceEvent<RoomKeyRequestEventContent>,
) -> OlmResult<()> {
) -> OlmResult<Option<Session>> {
let key_info = match event.content.action {
Action::Request => {
if let Some(info) = &event.content.body {
@ -305,11 +308,11 @@ impl KeyRequestMachine {
action, but no key info was found",
event.sender, event.content.requesting_device_id
);
return Ok(());
return Ok(None);
}
}
// We ignore cancellations here since there's nothing to serve.
Action::CancelRequest => return Ok(()),
Action::CancelRequest => return Ok(None),
};
let session = self
@ -328,7 +331,7 @@ impl KeyRequestMachine {
"Received a key request from {} {} for an unknown inbound group session {}.",
&event.sender, &event.content.requesting_device_id, &key_info.session_id
);
return Ok(());
return Ok(None);
};
let device = self
@ -349,6 +352,8 @@ impl KeyRequestMachine {
device.device_id(),
e
);
Ok(None)
} else {
info!(
"Serving a key request for {} from {} {}.",
@ -357,20 +362,20 @@ impl KeyRequestMachine {
device.device_id()
);
if let Err(e) = self.share_session(&session, &device).await {
match e {
OlmError::MissingSession => {
info!(
"Key request from {} {} is missing an Olm session, \
putting the request in the wait queue",
device.user_id(),
device.device_id()
);
self.handle_key_share_without_session(device, event);
return Ok(());
}
e => return Err(e),
match self.share_session(&session, &device).await {
Ok(s) => Ok(Some(s)),
Err(OlmError::MissingSession) => {
info!(
"Key request from {} {} is missing an Olm session, \
putting the request in the wait queue",
device.user_id(),
device.device_id()
);
self.handle_key_share_without_session(device, event);
Ok(None)
}
Err(e) => Err(e),
}
}
} else {
@ -379,13 +384,17 @@ impl KeyRequestMachine {
&event.sender, &event.content.requesting_device_id
);
self.store.update_tracked_user(&event.sender, true).await?;
}
Ok(())
Ok(None)
}
}
async fn share_session(&self, session: &InboundGroupSession, device: &Device) -> OlmResult<()> {
let content = device.encrypt_session(session.clone()).await?;
async fn share_session(
&self,
session: &InboundGroupSession,
device: &Device,
) -> OlmResult<Session> {
let (used_session, content) = device.encrypt_session(session.clone()).await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
@ -412,7 +421,7 @@ impl KeyRequestMachine {
self.outgoing_to_device_requests.insert(id, request);
Ok(())
Ok(used_session)
}
/// Check if it's ok to share a session with the given device.
@ -569,23 +578,20 @@ impl KeyRequestMachine {
Ok(())
}
/// Save an inbound group session we received using a key forward.
/// Mark the given outgoing key info as done.
///
/// At the same time delete the key info since we received the wanted key.
async fn save_session(
&self,
key_info: OugoingKeyInfo,
session: InboundGroupSession,
) -> Result<(), CryptoStoreError> {
/// This will queue up a request cancelation.
async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> {
// TODO perhaps only remove the key info if the first known index is 0.
trace!(
"Successfully received a forwarded room key for {:#?}",
key_info
);
self.store.save_inbound_group_sessions(&[session]).await?;
self.outgoing_to_device_requests
.remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction.
self.delete_key_info(&key_info).await?;
let content = RoomKeyRequestEventContent {
@ -609,7 +615,8 @@ impl KeyRequestMachine {
&self,
sender_key: &str,
event: &mut ToDeviceEvent<ForwardedRoomKeyEventContent>,
) -> Result<Option<Raw<AnyToDeviceEvent>>, CryptoStoreError> {
) -> Result<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>), CryptoStoreError>
{
let key_info = self.get_key_info(&event.content).await?;
if let Some(info) = key_info {
@ -626,27 +633,32 @@ impl KeyRequestMachine {
// If we have a previous session, check if we have a better version
// and store the new one if so.
if let Some(old_session) = old_session {
let session = if let Some(old_session) = old_session {
let first_old_index = old_session.first_known_index().await;
let first_index = session.first_known_index().await;
if first_old_index > first_index {
self.save_session(info, session).await?;
self.mark_as_done(info).await?;
Some(session)
} else {
None
}
// If we didn't have a previous session, store it.
} else {
self.save_session(info, session).await?;
}
self.mark_as_done(info).await?;
Some(session)
};
Ok(Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey(
event.clone(),
))))
Ok((
Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey(event.clone()))),
session,
))
} else {
info!(
"Received a forwarded room key from {}, but no key info was found.",
event.sender,
);
Ok(None)
Ok((None, None))
}
}
}
@ -831,20 +843,20 @@ mod test {
.is_none()
);
machine
let (_, first_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let first_session = machine
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
let first_session = first_session.unwrap();
assert_eq!(first_session.first_known_index().await, 10);
machine
.store
.save_inbound_group_sessions(&[first_session.clone()])
.await
.unwrap();
// Get the cancel request.
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
let id = request.request_id;
@ -875,19 +887,12 @@ mod test {
content,
};
machine
let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let second_session = machine
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
assert_eq!(second_session.first_known_index().await, 10);
assert!(second_session.is_none());
let export = session.export_at_index(0).await.unwrap();
@ -898,18 +903,12 @@ mod test {
content,
};
machine
let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let second_session = machine
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
assert_eq!(second_session.first_known_index().await, 0);
assert_eq!(second_session.unwrap().first_known_index().await, 0);
}
#[async_test]
@ -1132,14 +1131,19 @@ mod test {
.unwrap()
.is_none());
let (decrypted, sender_key, _) =
let (_, decrypted, sender_key, _) =
alice_account.decrypt_to_device_event(&event).await.unwrap();
if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() {
alice_machine
let (_, session) = alice_machine
.receive_forwarded_room_key(&sender_key, &mut e)
.await
.unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else {
panic!("Invalid decrypted event type");
}
@ -1315,14 +1319,19 @@ mod test {
.unwrap()
.is_none());
let (decrypted, sender_key, _) =
let (_, decrypted, sender_key, _) =
alice_account.decrypt_to_device_event(&event).await.unwrap();
if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() {
alice_machine
let (_, session) = alice_machine
.receive_forwarded_room_key(&sender_key, &mut e)
.await
.unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else {
panic!("Invalid decrypted event type");
}

View file

@ -47,15 +47,18 @@ use matrix_sdk_common::{
use crate::store::sqlite::SqliteStore;
use crate::{
error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult},
identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities},
identities::{Device, IdentityManager, UserDevices},
key_request::KeyRequestMachine,
olm::{
Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys,
InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount,
InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session,
},
requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest},
session_manager::{GroupSessionManager, SessionManager},
store::{CryptoStore, MemoryStore, Result as StoreResult, Store},
store::{
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
Store,
},
verification::{Sas, VerificationMachine},
ToDeviceRequest,
};
@ -467,7 +470,7 @@ impl OlmMachine {
async fn receive_keys_query_response(
&self,
response: &KeysQueryResponse,
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> {
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
self.identity_manager
.receive_keys_query_response(response)
.await
@ -498,12 +501,12 @@ impl OlmMachine {
async fn decrypt_to_device_event(
&self,
event: &ToDeviceEvent<EncryptedEventContent>,
) -> OlmResult<Raw<AnyToDeviceEvent>> {
let (decrypted_event, sender_key, signing_key) =
) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, Option<InboundGroupSession>)> {
let (session, decrypted_event, sender_key, signing_key) =
self.account.decrypt_to_device_event(event).await?;
// Handle the decrypted event, e.g. fetch out Megolm sessions out of
// the event.
if let Some(event) = self
if let (Some(event), group_session) = self
.handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event)
.await?
{
@ -512,9 +515,9 @@ impl OlmMachine {
// don't want them to be able to do silly things with it. Handling
// events modifies them and returns a modified one, so replace it
// here if we get one.
Ok(event)
Ok((session, event, group_session))
} else {
Ok(decrypted_event)
Ok((session, decrypted_event, None))
}
}
@ -524,7 +527,7 @@ impl OlmMachine {
sender_key: &str,
signing_key: &str,
event: &mut ToDeviceEvent<RoomKeyEventContent>,
) -> OlmResult<Option<Raw<AnyToDeviceEvent>>> {
) -> OlmResult<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>)> {
match event.content.algorithm {
EventEncryptionAlgorithm::MegolmV1AesSha2 => {
let session_key = GroupSessionKey(mem::take(&mut event.content.session_key));
@ -535,17 +538,15 @@ impl OlmMachine {
&event.content.room_id,
session_key,
)?;
let _ = self.store.save_inbound_group_sessions(&[session]).await?;
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
Ok(Some(event))
Ok((Some(event), Some(session)))
}
_ => {
warn!(
"Received room key with unsupported key algorithm {}",
event.content.algorithm
);
Ok(None)
Ok((None, None))
}
}
}
@ -555,9 +556,14 @@ impl OlmMachine {
&self,
room_id: &RoomId,
) -> OlmResult<()> {
self.group_session_manager
let (_, session) = self
.group_session_manager
.create_outbound_group_session(room_id, EncryptionSettings::default())
.await
.await?;
self.store.save_inbound_group_sessions(&[session]).await?;
Ok(())
}
/// Encrypt a room message for the given room.
@ -647,12 +653,12 @@ impl OlmMachine {
sender_key: &str,
signing_key: &str,
event: &Raw<AnyToDeviceEvent>,
) -> OlmResult<Option<Raw<AnyToDeviceEvent>>> {
) -> OlmResult<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>)> {
let event = if let Ok(e) = event.deserialize() {
e
} else {
warn!("Decrypted to-device event failed to be parsed correctly");
return Ok(None);
return Ok((None, None));
};
match event {
@ -665,7 +671,7 @@ impl OlmMachine {
.await?),
_ => {
warn!("Received a unexpected encrypted to-device event");
Ok(None)
Ok((None, None))
}
}
}
@ -699,11 +705,8 @@ impl OlmMachine {
self.verification_machine.get_sas(flow_id)
}
async fn update_one_time_key_count(
&self,
key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>,
) -> StoreResult<()> {
self.account.update_uploaded_key_count(key_count).await
async fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
self.account.update_uploaded_key_count(key_count).await;
}
/// Handle a sync response and update the internal state of the Olm machine.
@ -719,15 +722,19 @@ impl OlmMachine {
///
/// [`decrypt_room_event`]: #method.decrypt_room_event
#[instrument(skip(response))]
pub async fn receive_sync_response(&self, response: &mut SyncResponse) {
pub async fn receive_sync_response(&self, response: &mut SyncResponse) -> OlmResult<()> {
// Remove verification objects that have expired or are done.
self.verification_machine.garbage_collect();
if let Err(e) = self
.update_one_time_key_count(&response.device_one_time_keys_count)
.await
{
error!("Error updating the one-time key count {:?}", e);
}
// Always save the account, a new session might get created which also
// touches the account.
let mut changes = Changes {
account: Some(self.account.inner.clone()),
..Default::default()
};
self.update_one_time_key_count(&response.device_one_time_keys_count)
.await;
for user_id in &response.device_lists.changed {
if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await {
@ -748,29 +755,36 @@ impl OlmMachine {
match &mut event {
AnyToDeviceEvent::RoomEncrypted(e) => {
let decrypted_event = match self.decrypt_to_device_event(e).await {
Ok(e) => e,
Err(err) => {
warn!(
"Failed to decrypt to-device event from {} {}",
e.sender, err
);
let (session, decrypted_event, group_session) =
match self.decrypt_to_device_event(e).await {
Ok(e) => e,
Err(err) => {
warn!(
"Failed to decrypt to-device event from {} {}",
e.sender, err
);
if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self
.session_manager
.mark_device_as_wedged(&sender, &curve_key)
.await
{
error!(
"Couldn't mark device from {} to be unwedged {:?}",
sender, e
);
if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self
.session_manager
.mark_device_as_wedged(&sender, &curve_key)
.await
{
error!(
"Couldn't mark device from {} to be unwedged {:?}",
sender, e
);
}
}
continue;
}
continue;
}
};
};
changes.sessions.push(session);
if let Some(group_session) = group_session {
changes.inbound_group_sessions.push(group_session);
}
*event_result = decrypted_event;
}
@ -789,13 +803,14 @@ impl OlmMachine {
}
}
if let Err(e) = self
let changed_sessions = self
.key_request_machine
.collect_incoming_key_requests()
.await
{
error!("Error collecting our key share requests {:?}", e);
}
.await?;
changes.sessions.extend(changed_sessions);
Ok(self.store.save_changes(changes).await?)
}
/// Decrypt an event from a room timeline.
@ -973,7 +988,13 @@ impl OlmMachine {
let num_sessions = sessions.len();
self.store.save_inbound_group_sessions(&sessions).await?;
let changes = Changes {
inbound_group_sessions: sessions,
..Default::default()
};
self.store.save_changes(changes).await?;
info!(
"Successfully imported {} inbound group sessions",
num_sessions
@ -1198,15 +1219,19 @@ pub(crate) mod test {
.unwrap()
.unwrap();
let (session, content) = bob_device
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap();
alice.store.save_sessions(&[session]).await.unwrap();
let event = ToDeviceEvent {
sender: alice.user_id().clone(),
content: bob_device
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap(),
content,
};
bob.decrypt_to_device_event(&event).await.unwrap();
let (session, _, _) = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store.save_sessions(&[session]).await.unwrap();
(alice, bob)
}
@ -1492,13 +1517,15 @@ pub(crate) mod test {
content: bob_device
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap(),
.unwrap()
.1,
};
let event = bob
.decrypt_to_device_event(&event)
.await
.unwrap()
.1
.deserialize()
.unwrap();
@ -1534,12 +1561,14 @@ pub(crate) mod test {
.get_outbound_group_session(&room_id)
.unwrap();
let event = bob
.decrypt_to_device_event(&event)
let (session, event, group_session) = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store.save_sessions(&[session]).await.unwrap();
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await
.unwrap()
.deserialize()
.unwrap();
let event = event.deserialize().unwrap();
if let AnyToDeviceEvent::RoomKey(event) = event {
assert_eq!(&event.sender, alice.user_id());
@ -1579,7 +1608,11 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests),
};
bob.decrypt_to_device_event(&event).await.unwrap();
let (_, _, group_session) = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await
.unwrap();
let plaintext = "It is a secret to everybody";

View file

@ -52,7 +52,7 @@ use olm_rs::{
use crate::{
error::{EventError, OlmResult, SessionCreationError},
identities::ReadOnlyDevice,
store::{Result as StoreResult, Store},
store::Store,
OlmError,
};
@ -76,7 +76,7 @@ impl Account {
pub async fn decrypt_to_device_event(
&self,
event: &ToDeviceEvent<EncryptedEventContent>,
) -> OlmResult<(Raw<AnyToDeviceEvent>, String, String)> {
) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, String, String)> {
debug!("Decrypting to-device event");
let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content {
@ -103,27 +103,28 @@ impl Account {
.map_err(|_| EventError::UnsupportedOlmType)?;
// Decrypt the OlmMessage and get a Ruma event out of it.
let (decrypted_event, signing_key) = self
let (session, decrypted_event, signing_key) = self
.decrypt_olm_message(&event.sender, &content.sender_key, message)
.await?;
debug!("Decrypted a to-device event {:?}", decrypted_event);
Ok((decrypted_event, content.sender_key.clone(), signing_key))
Ok((
session,
decrypted_event,
content.sender_key.clone(),
signing_key,
))
} else {
warn!("Olm event doesn't contain a ciphertext for our key");
Err(EventError::MissingCiphertext.into())
}
}
pub async fn update_uploaded_key_count(
&self,
key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>,
) -> StoreResult<()> {
pub async fn update_uploaded_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
let one_time_key_count = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
self.inner.update_uploaded_key_count(count);
self.store.save_account(self.inner.clone()).await
}
pub async fn receive_keys_upload_response(
@ -161,7 +162,7 @@ impl Account {
sender: &UserId,
sender_key: &str,
message: &OlmMessage,
) -> OlmResult<Option<String>> {
) -> OlmResult<Option<(Session, String)>> {
let s = self.store.get_sessions(sender_key).await?;
// We don't have any existing sessions, return early.
@ -171,8 +172,7 @@ impl Account {
return Ok(None);
};
let mut session_to_save = None;
let mut plaintext = None;
let mut decrypted: Option<(Session, String)> = None;
for session in &mut *sessions.lock().await {
let mut matches = false;
@ -191,9 +191,7 @@ impl Account {
match ret {
Ok(p) => {
plaintext = Some(p);
session_to_save = Some(session.clone());
decrypted = Some((session.clone(), p));
break;
}
Err(e) => {
@ -214,14 +212,7 @@ impl Account {
}
}
if let Some(session) = session_to_save {
// Decryption was successful, save the new ratchet state of the
// session that was used to decrypt the message.
trace!("Saved the new session state for {}", sender);
self.store.save_sessions(&[session]).await?;
}
Ok(plaintext)
Ok(decrypted)
}
/// Decrypt an Olm message, creating a new Olm session if possible.
@ -230,15 +221,15 @@ impl Account {
sender: &UserId,
sender_key: &str,
message: OlmMessage,
) -> OlmResult<(Raw<AnyToDeviceEvent>, String)> {
) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, String)> {
// First try to decrypt using an existing session.
let plaintext = if let Some(p) = self
let (session, plaintext) = if let Some(d) = self
.try_decrypt_olm_message(sender, sender_key, &message)
.await?
{
// Decryption succeeded, de-structure the plaintext out of the
// Option.
p
// Decryption succeeded, de-structure the session/plaintext out of
// the Option.
d
} else {
// Decryption failed with every known session, let's try to create a
// new session.
@ -278,9 +269,6 @@ impl Account {
}
};
// Save the account since we remove the one-time key that
// was used to create this session.
self.store.save_account(self.inner.clone()).await?;
session
}
};
@ -288,15 +276,23 @@ impl Account {
// Decrypt our message, this shouldn't fail since we're using a
// newly created Session.
let plaintext = session.decrypt(message).await?;
// Save the new ratcheted state of the session.
self.store.save_sessions(&[session]).await?;
plaintext
(session, plaintext)
};
trace!("Successfully decrypted a Olm message: {}", plaintext);
self.parse_decrypted_to_device_event(sender, &plaintext)
let (event, signing_key) = match self.parse_decrypted_to_device_event(sender, &plaintext) {
Ok(r) => r,
Err(e) => {
// We might created a new session but decryption might still
// have failed, store it for the error case here, this is fine
// since we don't expect this to happen often or at all.
self.store.save_sessions(&[session]).await?;
return Err(e);
}
};
Ok((session, event, signing_key))
}
/// Parse a decrypted Olm message, check that the plaintext and encrypted

View file

@ -28,8 +28,8 @@ use tracing::{debug, info};
use crate::{
error::{EventError, MegolmResult, OlmResult},
olm::{Account, OutboundGroupSession},
store::Store,
olm::{Account, InboundGroupSession, OutboundGroupSession},
store::{Changes, Store},
Device, EncryptionSettings, OlmError, ToDeviceRequest,
};
@ -140,19 +140,17 @@ impl GroupSessionManager {
&self,
room_id: &RoomId,
settings: EncryptionSettings,
) -> OlmResult<()> {
) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
let (outbound, inbound) = self
.account
.create_group_session_pair(room_id, settings)
.await
.map_err(|_| EventError::UnsupportedAlgorithm)?;
let _ = self.store.save_inbound_group_sessions(&[inbound]).await?;
let _ = self
.outbound_group_sessions
.insert(room_id.to_owned(), outbound);
Ok(())
.insert(room_id.to_owned(), outbound.clone());
Ok((outbound, inbound))
}
/// Get to-device requests to share a group session with users in a room.
@ -169,13 +167,12 @@ impl GroupSessionManager {
users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.create_outbound_group_session(room_id, encryption_settings.into())
.await?;
let session = self.outbound_group_sessions.get(room_id).unwrap();
let mut changes = Changes::default();
if session.shared() {
panic!("Session is already shared");
}
let (session, inbound_session) = self
.create_outbound_group_session(room_id, encryption_settings.into())
.await?;
changes.inbound_group_sessions.push(inbound_session);
let mut devices: Vec<Device> = Vec::new();
@ -196,7 +193,7 @@ impl GroupSessionManager {
.encrypt(EventType::RoomKey, key_content.clone())
.await;
let encrypted = match encrypted {
let (used_session, encrypted) = match encrypted {
Ok(c) => c,
Err(OlmError::MissingSession)
| Err(OlmError::EventError(EventError::MissingSenderKey)) => {
@ -205,6 +202,8 @@ impl GroupSessionManager {
Err(e) => return Err(e),
};
changes.sessions.push(used_session);
messages
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
@ -237,6 +236,8 @@ impl GroupSessionManager {
session.mark_as_shared();
}
self.store.save_changes(changes).await?;
Ok(requests)
}
}

View file

@ -33,7 +33,7 @@ use crate::{
key_request::KeyRequestMachine,
olm::Account,
requests::{OutgoingRequest, ToDeviceRequest},
store::{Result as StoreResult, Store},
store::{Changes, Result as StoreResult, Store},
ReadOnlyDevice,
};
@ -128,7 +128,7 @@ impl SessionManager {
.is_some()
{
if let Some(device) = self.store.get_device(user_id, device_id).await? {
let content = device.encrypt(EventType::Dummy, json!({})).await?;
let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
@ -258,6 +258,8 @@ impl SessionManager {
pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
// TODO log the failures here
let mut changes = Changes::default();
for (user_id, user_devices) in &response.one_time_keys {
for (device_id, key_map) in user_devices {
let device = match self.store.get_readonly_device(&user_id, device_id).await {
@ -284,15 +286,12 @@ impl SessionManager {
let session = match self.account.create_outbound_session(device, &key_map).await {
Ok(s) => s,
Err(e) => {
warn!("{:?}", e);
warn!("Error creating new outbound session {:?}", e);
continue;
}
};
if let Err(e) = self.store.save_sessions(&[session]).await {
error!("Failed to store newly created Olm session {}", e);
continue;
}
changes.sessions.push(session);
self.key_request_machine.retry_keyshare(&user_id, device_id);
@ -304,7 +303,8 @@ impl SessionManager {
}
}
}
Ok(())
Ok(self.store.save_changes(changes).await?)
}
}

View file

@ -26,7 +26,7 @@ use matrix_sdk_common_macros::async_trait;
use super::{
caches::{DeviceStore, GroupSessionStore, SessionStore},
CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
};
use crate::identities::{ReadOnlyDevice, UserIdentities};
@ -61,6 +61,30 @@ impl MemoryStore {
pub fn new() -> Self {
Self::default()
}
pub(crate) async fn save_devices(&self, mut devices: Vec<ReadOnlyDevice>) {
for device in devices.drain(..) {
let _ = self.devices.add(device);
}
}
async fn delete_devices(&self, mut devices: Vec<ReadOnlyDevice>) {
for device in devices.drain(..) {
let _ = self.devices.remove(device.user_id(), device.device_id());
}
}
async fn save_sessions(&self, mut sessions: Vec<Session>) {
for session in sessions.drain(..) {
let _ = self.sessions.add(session.clone()).await;
}
}
async fn save_inbound_group_sessions(&self, mut sessions: Vec<InboundGroupSession>) {
for session in sessions.drain(..) {
self.inbound_group_sessions.add(session);
}
}
}
#[async_trait]
@ -73,9 +97,24 @@ impl CryptoStore for MemoryStore {
Ok(())
}
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
for session in sessions {
let _ = self.sessions.add(session.clone()).await;
async fn save_changes(&self, mut changes: Changes) -> Result<()> {
self.save_sessions(changes.sessions).await;
self.save_inbound_group_sessions(changes.inbound_group_sessions)
.await;
self.save_devices(changes.devices.new).await;
self.save_devices(changes.devices.changed).await;
self.delete_devices(changes.devices.deleted).await;
for identity in changes
.identities
.new
.drain(..)
.chain(changes.identities.changed)
{
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
}
Ok(())
@ -85,14 +124,6 @@ impl CryptoStore for MemoryStore {
Ok(self.sessions.get(sender_key))
}
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
for session in sessions {
self.inbound_group_sessions.add(session.clone());
}
Ok(())
}
async fn get_inbound_group_session(
&self,
room_id: &RoomId,
@ -151,11 +182,6 @@ impl CryptoStore for MemoryStore {
Ok(self.devices.get(user_id, device_id))
}
async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> {
let _ = self.devices.remove(device.user_id(), device.device_id());
Ok(())
}
async fn get_user_devices(
&self,
user_id: &UserId,
@ -163,28 +189,11 @@ impl CryptoStore for MemoryStore {
Ok(self.devices.user_devices(user_id))
}
async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
for device in devices {
let _ = self.devices.add(device.clone());
}
Ok(())
}
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
#[allow(clippy::map_clone)]
Ok(self.identities.get(user_id).map(|i| i.clone()))
}
async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()> {
for identity in identities {
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
}
Ok(())
}
async fn save_value(&self, key: String, value: String) -> Result<()> {
self.values.insert(key, value);
Ok(())
@ -217,7 +226,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none());
store.save_account(account).await.unwrap();
store.save_sessions(&[session.clone()]).await.unwrap();
store.save_sessions(vec![session.clone()]).await;
let sessions = store
.get_sessions(&session.sender_key)
@ -250,9 +259,8 @@ mod test {
let store = MemoryStore::new();
let _ = store
.save_inbound_group_sessions(&[inbound.clone()])
.await
.unwrap();
.save_inbound_group_sessions(vec![inbound.clone()])
.await;
let loaded_session = store
.get_inbound_group_session(&room_id, "test_key", outbound.session_id())
@ -267,7 +275,7 @@ mod test {
let device = get_device();
let store = MemoryStore::new();
store.save_devices(&[device.clone()]).await.unwrap();
store.save_devices(vec![device.clone()]).await;
let loaded_device = store
.get_device(device.user_id(), device.device_id())
@ -286,7 +294,7 @@ mod test {
assert_eq!(&device, loaded_device);
store.delete_device(device.clone()).await.unwrap();
store.delete_devices(vec![device.clone()]).await;
assert!(store
.get_device(device.user_id(), device.device_id())
.await

View file

@ -100,6 +100,31 @@ pub(crate) struct Store {
verification_machine: VerificationMachine,
}
#[derive(Debug, Default)]
#[allow(missing_docs)]
pub struct Changes {
pub account: Option<ReadOnlyAccount>,
pub sessions: Vec<Session>,
pub inbound_group_sessions: Vec<InboundGroupSession>,
pub identities: IdentityChanges,
pub devices: DeviceChanges,
}
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct IdentityChanges {
pub new: Vec<UserIdentities>,
pub changed: Vec<UserIdentities>,
}
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct DeviceChanges {
pub new: Vec<ReadOnlyDevice>,
pub changed: Vec<ReadOnlyDevice>,
pub deleted: Vec<ReadOnlyDevice>,
}
impl Store {
pub fn new(
user_id: Arc<UserId>,
@ -121,6 +146,41 @@ impl Store {
self.inner.get_device(user_id, device_id).await
}
pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let changes = Changes {
sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await
}
#[cfg(test)]
pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let changes = Changes {
devices: DeviceChanges {
changed: devices.to_vec(),
..Default::default()
},
..Default::default()
};
self.save_changes(changes).await
}
#[cfg(test)]
pub async fn save_inbound_group_sessions(
&self,
sessions: &[InboundGroupSession],
) -> Result<()> {
let changes = Changes {
inbound_group_sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await
}
pub async fn get_readonly_devices(
&self,
user_id: &UserId,
@ -271,12 +331,8 @@ pub trait CryptoStore: Debug {
/// * `account` - The account that should be stored.
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>;
/// Save the given sessions in the store.
///
/// # Arguments
///
/// * `session` - The sessions that should be stored.
async fn save_sessions(&self, session: &[Session]) -> Result<()>;
/// TODO
async fn save_changes(&self, changes: Changes) -> Result<()>;
/// Get all the sessions that belong to the given sender key.
///
@ -285,13 +341,6 @@ pub trait CryptoStore: Debug {
/// * `sender_key` - The sender key that was used to establish the sessions.
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
/// Save the given inbound group sessions in the store.
///
/// # Arguments
///
/// * `sessions` - The sessions that should be stored.
async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>;
/// Get the inbound group session from our store.
///
/// # Arguments
@ -331,20 +380,6 @@ pub trait CryptoStore: Debug {
/// * `dirty` - Should the user be also marked for a key query.
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>;
/// Save the given devices in the store.
///
/// # Arguments
///
/// * `device` - The device that should be stored.
async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()>;
/// Delete the given device from the store.
///
/// # Arguments
///
/// * `device` - The device that should be stored.
async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()>;
/// Get the device for the given user with the given device id.
///
/// # Arguments
@ -368,13 +403,6 @@ pub trait CryptoStore: Debug {
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>>;
/// Save the given user identities in the store.
///
/// # Arguments
///
/// * `identities` - The identities that should be saved in the store.
async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()>;
/// Get the user identity that is attached to the given user id.
///
/// # Arguments

View file

@ -34,7 +34,7 @@ use matrix_sdk_common::{
use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection};
use zeroize::Zeroizing;
use super::{caches::SessionStore, CryptoStore, CryptoStoreError, Result};
use super::{caches::SessionStore, Changes, CryptoStore, CryptoStoreError, Result};
use crate::{
identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity},
olm::{
@ -456,11 +456,17 @@ impl SqliteStore {
Ok(())
}
async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> {
async fn lazy_load_sessions(
&self,
connection: &mut SqliteConnection,
sender_key: &str,
) -> Result<()> {
let loaded_sessions = self.sessions.get(sender_key).is_some();
if !loaded_sessions {
let sessions = self.load_sessions_for(sender_key).await?;
let sessions = self
.load_sessions_for_helper(connection, sender_key)
.await?;
if !sessions.is_empty() {
self.sessions.set_for_sender(sender_key, sessions);
@ -470,20 +476,33 @@ impl SqliteStore {
Ok(())
}
async fn get_sessions_for(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
self.lazy_load_sessions(sender_key).await?;
async fn get_sessions_for(
&self,
connection: &mut SqliteConnection,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
self.lazy_load_sessions(connection, sender_key).await?;
Ok(self.sessions.get(sender_key))
}
#[cfg(test)]
async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> {
let mut connection = self.connection.lock().await;
self.load_sessions_for_helper(&mut connection, sender_key)
.await
}
async fn load_sessions_for_helper(
&self,
connection: &mut SqliteConnection,
sender_key: &str,
) -> Result<Vec<Session>> {
let account_info = self
.account_info
.lock()
.unwrap()
.clone()
.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let mut rows: Vec<(String, String, String, String)> = query_as(
"SELECT pickle, sender_key, creation_time, last_use_time
FROM sessions WHERE account_id = ? and sender_key = ?",
@ -1231,6 +1250,134 @@ impl SqliteStore {
Ok(())
}
#[cfg(test)]
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
self.save_sessions_helper(&mut transaction, sessions)
.await?;
transaction.commit().await?;
Ok(())
}
async fn save_sessions_helper(
&self,
connection: &mut SqliteConnection,
sessions: &[Session],
) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for session in sessions {
self.lazy_load_sessions(connection, &session.sender_key)
.await?;
}
for session in sessions {
self.sessions.add(session.clone()).await;
let pickle = session.pickle(self.get_pickle_mode()).await;
let session_id = session.session_id();
let creation_time = serde_json::to_string(&pickle.creation_time)?;
let last_use_time = serde_json::to_string(&pickle.last_use_time)?;
query(
"REPLACE INTO sessions (
session_id, account_id, creation_time, last_use_time, sender_key, pickle
) VALUES (?, ?, ?, ?, ?, ?)",
)
.bind(&session_id)
.bind(&account_id)
.bind(&*creation_time)
.bind(&*last_use_time)
.bind(&pickle.sender_key)
.bind(&pickle.pickle.as_str())
.execute(&mut *connection)
.await?;
}
Ok(())
}
async fn save_devices(
&self,
mut connection: &mut SqliteConnection,
devices: &[ReadOnlyDevice],
) -> Result<()> {
for device in devices {
self.save_device_helper(&mut connection, device.clone())
.await?
}
Ok(())
}
async fn delete_devices(
&self,
connection: &mut SqliteConnection,
devices: &[ReadOnlyDevice],
) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for device in devices {
query(
"DELETE FROM devices
WHERE account_id = ?1 and user_id = ?2 and device_id = ?3
",
)
.bind(account_id)
.bind(&device.user_id().to_string())
.bind(device.device_id().as_str())
.execute(&mut *connection)
.await?;
}
Ok(())
}
#[cfg(test)]
async fn save_inbound_group_sessions_test(
&self,
sessions: &[InboundGroupSession],
) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
self.save_inbound_group_sessions(&mut transaction, sessions)
.await?;
transaction.commit().await?;
Ok(())
}
async fn save_inbound_group_sessions(
&self,
connection: &mut SqliteConnection,
sessions: &[InboundGroupSession],
) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for session in sessions {
self.save_inbound_group_session_helper(account_id, connection, session)
.await?;
}
Ok(())
}
async fn save_user_identities(
&self,
mut connection: &mut SqliteConnection,
users: &[UserIdentities],
) -> Result<()> {
for user in users {
self.save_user_helper(&mut connection, user).await?;
}
Ok(())
}
async fn save_user_helper(
&self,
mut connection: &mut SqliteConnection,
@ -1369,39 +1516,26 @@ impl CryptoStore for SqliteStore {
Ok(())
}
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for session in sessions {
self.lazy_load_sessions(&session.sender_key).await?;
}
async fn save_changes(&self, changes: Changes) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
for session in sessions {
self.sessions.add(session.clone()).await;
let pickle = session.pickle(self.get_pickle_mode()).await;
let session_id = session.session_id();
let creation_time = serde_json::to_string(&pickle.creation_time)?;
let last_use_time = serde_json::to_string(&pickle.last_use_time)?;
query(
"REPLACE INTO sessions (
session_id, account_id, creation_time, last_use_time, sender_key, pickle
) VALUES (?, ?, ?, ?, ?, ?)",
)
.bind(&session_id)
.bind(&account_id)
.bind(&*creation_time)
.bind(&*last_use_time)
.bind(&pickle.sender_key)
.bind(&pickle.pickle.as_str())
.execute(&mut *transaction)
self.save_sessions_helper(&mut transaction, &changes.sessions)
.await?;
self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions)
.await?;
self.save_devices(&mut transaction, &changes.devices.new)
.await?;
self.save_devices(&mut transaction, &changes.devices.changed)
.await?;
self.delete_devices(&mut transaction, &changes.devices.deleted)
.await?;
self.save_user_identities(&mut transaction, &changes.identities.new)
.await?;
self.save_user_identities(&mut transaction, &changes.identities.changed)
.await?;
}
transaction.commit().await?;
@ -1409,22 +1543,8 @@ impl CryptoStore for SqliteStore {
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
Ok(self.get_sessions_for(sender_key).await?)
}
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
for session in sessions {
self.save_inbound_group_session_helper(account_id, &mut transaction, session)
.await?;
}
transaction.commit().await?;
Ok(())
Ok(self.get_sessions_for(&mut connection, sender_key).await?)
}
async fn get_inbound_group_session(
@ -1469,38 +1589,6 @@ impl CryptoStore for SqliteStore {
Ok(already_added)
}
async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
for device in devices {
self.save_device_helper(&mut transaction, device.clone())
.await?
}
transaction.commit().await?;
Ok(())
}
async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
query(
"DELETE FROM devices
WHERE account_id = ?1 and user_id = ?2 and device_id = ?3
",
)
.bind(account_id)
.bind(&device.user_id().to_string())
.bind(device.device_id().as_str())
.execute(&mut *connection)
.await?;
Ok(())
}
async fn get_device(
&self,
user_id: &UserId,
@ -1520,19 +1608,6 @@ impl CryptoStore for SqliteStore {
self.load_user(user_id).await
}
async fn save_user_identities(&self, users: &[UserIdentities]) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
for user in users {
self.save_user_helper(&mut transaction, user).await?;
}
transaction.commit().await?;
Ok(())
}
async fn save_value(&self, key: String, value: String) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
@ -1598,6 +1673,7 @@ mod test {
user::test::{get_other_identity, get_own_identity},
},
olm::{GroupSessionKey, InboundGroupSession, ReadOnlyAccount, Session},
store::{Changes, DeviceChanges, IdentityChanges},
};
use matrix_sdk_common::{
api::r0::keys::SignedKey,
@ -1854,7 +1930,7 @@ mod test {
.expect("Can't create session");
store
.save_inbound_group_sessions(&[session])
.save_inbound_group_sessions_test(&[session])
.await
.expect("Can't save group session");
}
@ -1880,7 +1956,7 @@ mod test {
let session = InboundGroupSession::from_export(export).unwrap();
store
.save_inbound_group_sessions(&[session.clone()])
.save_inbound_group_sessions_test(&[session.clone()])
.await
.expect("Can't save group session");
@ -1952,7 +2028,15 @@ mod test {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
store.save_devices(&[device.clone()]).await.unwrap();
let changes = Changes {
devices: DeviceChanges {
changed: vec![device.clone()],
..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
drop(store);
@ -1986,8 +2070,25 @@ mod test {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
store.save_devices(&[device.clone()]).await.unwrap();
store.delete_device(device.clone()).await.unwrap();
let changes = Changes {
devices: DeviceChanges {
changed: vec![device.clone()],
..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
let changes = Changes {
devices: DeviceChanges {
deleted: vec![device.clone()],
..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path())
.await
@ -2024,8 +2125,16 @@ mod test {
let own_identity = get_own_identity();
let changes = Changes {
identities: IdentityChanges {
changed: vec![own_identity.clone().into()],
..Default::default()
},
..Default::default()
};
store
.save_user_identities(&[own_identity.clone().into()])
.save_changes(changes)
.await
.expect("Can't save identity");
@ -2052,10 +2161,15 @@ mod test {
let other_identity = get_other_identity();
store
.save_user_identities(&[other_identity.clone().into()])
.await
.unwrap();
let changes = Changes {
identities: IdentityChanges {
changed: vec![other_identity.clone().into()],
..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
let loaded_user = store
.load_user(other_identity.user_id())
@ -2072,10 +2186,15 @@ mod test {
own_identity.mark_as_verified();
store
.save_user_identities(&[own_identity.into()])
.await
.unwrap();
let changes = Changes {
identities: IdentityChanges {
changed: vec![own_identity.into()],
..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
let loaded_user = store.load_user(&user_id).await.unwrap().unwrap();
assert!(loaded_user.own().unwrap().is_verified())
}

View file

@ -194,18 +194,14 @@ impl VerificationMachine {
self.receive_event_helper(&s, event);
if s.is_done() {
if !s.mark_device_as_verified().await? {
if let Some(r) = s.cancel() {
self.outgoing_to_device_messages.insert(
r.txn_id,
OutgoingRequest {
request_id: r.txn_id,
request: Arc::new(r.into()),
},
);
}
} else {
s.mark_identity_as_verified().await?;
if let Some(r) = s.mark_as_done().await? {
self.outgoing_to_device_messages.insert(
r.txn_id,
OutgoingRequest {
request_id: r.txn_id,
request: Arc::new(r.into()),
},
);
}
}
};
@ -258,17 +254,15 @@ mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
let store = MemoryStore::new();
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
let bob_store = MemoryStore::new();
let bob_device = ReadOnlyDevice::from_account(&bob).await;
let alice_device = ReadOnlyDevice::from_account(&alice).await;
store.save_devices(&[bob_device]).await.unwrap();
bob_store
.save_devices(&[alice_device.clone()])
.await
.unwrap();
store.save_devices(vec![bob_device]).await;
bob_store.save_devices(vec![alice_device.clone()]).await;
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(bob_store));
let machine = VerificationMachine::new(alice, Arc::new(Box::new(store)));
let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store, None);
machine

View file

@ -34,7 +34,7 @@ use matrix_sdk_common::{
use crate::{
identities::{LocalTrust, ReadOnlyDevice, UserIdentities},
store::{CryptoStore, CryptoStoreError},
store::{Changes, CryptoStore, CryptoStoreError, DeviceChanges},
ReadOnlyAccount, ToDeviceRequest,
};
@ -189,34 +189,64 @@ impl Sas {
(content, guard.is_done())
};
if done {
// TODO move the logic that marks and stores the device into the
// else branch and only after the identity was verified as well. We
// dont' want to verify one without the other.
if !self.mark_device_as_verified().await? {
return Ok(self.cancel());
} else {
self.mark_identity_as_verified().await?;
}
}
let cancel = if done {
self.mark_as_done().await?
} else {
None
};
Ok(content.map(|c| {
let content = AnyToDeviceEventContent::KeyVerificationMac(c);
self.content_to_request(content)
}))
if cancel.is_some() {
Ok(cancel)
} else {
Ok(content.map(|c| {
let content = AnyToDeviceEventContent::KeyVerificationMac(c);
self.content_to_request(content)
}))
}
}
pub(crate) async fn mark_identity_as_verified(&self) -> Result<bool, CryptoStoreError> {
pub(crate) async fn mark_as_done(&self) -> Result<Option<ToDeviceRequest>, CryptoStoreError> {
if let Some(device) = self.mark_device_as_verified().await? {
let identity = self.mark_identity_as_verified().await?;
let mut changes = Changes {
devices: DeviceChanges {
changed: vec![device],
..Default::default()
},
..Default::default()
};
if let Some(i) = identity {
changes.identities.changed.push(i);
}
self.store.save_changes(changes).await?;
Ok(None)
} else {
Ok(self.cancel())
}
}
pub(crate) async fn mark_identity_as_verified(
&self,
) -> Result<Option<UserIdentities>, CryptoStoreError> {
// If there wasn't an identity available during the verification flow
// return early as there's nothing to do.
if self.other_identity.is_none() {
return Ok(false);
return Ok(None);
}
// TODO signal an error, e.g. when the identity got deleted so we don't
// verify/save the device either.
let identity = self.store.get_user_identity(self.other_user_id()).await?;
if let Some(identity) = identity {
if identity.master_key() == self.other_identity.as_ref().unwrap().master_key() {
if self
.other_identity
.as_ref()
.map_or(false, |i| i.master_key() == identity.master_key())
{
if self
.verified_identities()
.map_or(false, |i| i.contains(&identity))
@ -228,13 +258,12 @@ impl Sas {
if let UserIdentities::Own(i) = &identity {
i.mark_as_verified();
self.store.save_user_identities(&[identity]).await?;
}
// TODO if we have the private part of the user signing
// key we should sign and upload a signature for this
// identity.
Ok(true)
Ok(Some(identity))
} else {
info!(
"The interactive verification process didn't contain a \
@ -243,7 +272,7 @@ impl Sas {
self.verified_identities(),
);
Ok(false)
Ok(None)
}
} else {
warn!(
@ -252,7 +281,7 @@ impl Sas {
identity.user_id(),
);
Ok(false)
Ok(None)
}
} else {
info!(
@ -260,11 +289,13 @@ impl Sas {
verification was going on.",
self.other_user_id(),
);
Ok(false)
Ok(None)
}
}
pub(crate) async fn mark_device_as_verified(&self) -> Result<bool, CryptoStoreError> {
pub(crate) async fn mark_device_as_verified(
&self,
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
let device = self
.store
.get_device(self.other_user_id(), self.other_device_id())
@ -283,12 +314,11 @@ impl Sas {
);
device.set_trust_state(LocalTrust::Verified);
self.store.save_devices(&[device]).await?;
// TODO if this is a device from our own user and we have
// the private part of the self signing key, we should sign
// the device and upload the signature.
Ok(true)
Ok(Some(device))
} else {
info!(
"The interactive verification process didn't contain a \
@ -297,7 +327,7 @@ impl Sas {
device.device_id()
);
Ok(false)
Ok(None)
}
} else {
warn!(
@ -306,7 +336,7 @@ impl Sas {
device.user_id(),
device.device_id()
);
Ok(false)
Ok(None)
}
} else {
let device = self.other_device();
@ -317,7 +347,7 @@ impl Sas {
device.user_id(),
device.device_id()
);
Ok(false)
Ok(None)
}
}
@ -777,12 +807,11 @@ mod test {
let bob_device = ReadOnlyDevice::from_account(&bob).await;
let alice_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
let bob_store = MemoryStore::new();
bob_store
.save_devices(&[alice_device.clone()])
.await
.unwrap();
bob_store.save_devices(vec![alice_device.clone()]).await;
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(bob_store));
let (alice, content) = Sas::start(alice, bob_device, alice_store, None);
let event = wrap_to_device_event(alice.user_id(), content);