diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index d6917de1..e5c6b408 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -16,7 +16,7 @@ #[cfg(feature = "encryption")] use std::{collections::BTreeMap, io::Write, path::PathBuf}; use std::{ - convert::{TryFrom, TryInto}, + convert::TryInto, fmt::{self, Debug}, future::Future, io::Read, @@ -803,7 +803,7 @@ impl Client { since: Option<&str>, server: Option<&ServerName>, ) -> Result { - let limit = limit.map(|n| UInt::try_from(n).ok()).flatten(); + let limit = limit.map(|n| UInt::from(n)); let request = assign!(get_public_rooms::Request::new(), { limit, diff --git a/matrix_sdk_crypto/src/error.rs b/matrix_sdk_crypto/src/error.rs index 6d9aacdc..24d68f1c 100644 --- a/matrix_sdk_crypto/src/error.rs +++ b/matrix_sdk_crypto/src/error.rs @@ -47,9 +47,16 @@ pub enum OlmError { Store(#[from] CryptoStoreError), /// The session with a device has become corrupted. - #[error("decryption failed likely because an Olm from {0} with sender key {1} was wedged")] + #[error( + "decryption failed likely because an Olm session from {0} with sender key {1} was wedged" + )] SessionWedged(UserId, String), + /// An Olm message got replayed while the Olm ratchet has already moved + /// forward. + #[error("decryption failed because an Olm message from {0} with sender key {1} was replayed")] + ReplayedMessage(UserId, String), + /// Encryption failed because the device does not have a valid Olm session /// with us. #[error( diff --git a/matrix_sdk_crypto/src/file_encryption/mod.rs b/matrix_sdk_crypto/src/file_encryption/mod.rs index 3c644bd9..45ccd310 100644 --- a/matrix_sdk_crypto/src/file_encryption/mod.rs +++ b/matrix_sdk_crypto/src/file_encryption/mod.rs @@ -14,10 +14,10 @@ fn decode_url_safe(input: impl AsRef<[u8]>) -> Result, DecodeError> { decode_config(input, URL_SAFE_NO_PAD) } -fn encode(input: impl AsRef<[u8]>) -> String { +pub fn encode(input: impl AsRef<[u8]>) -> String { encode_config(input, STANDARD_NO_PAD) } -fn encode_url_safe(input: impl AsRef<[u8]>) -> String { +pub fn encode_url_safe(input: impl AsRef<[u8]>) -> String { encode_config(input, URL_SAFE_NO_PAD) } diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 55f5b569..32cbd69f 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -1137,12 +1137,11 @@ mod test { .unwrap() .is_none()); - let (_, decrypted, sender_key, _) = - alice_account.decrypt_to_device_event(&event).await.unwrap(); + let decrypted = alice_account.decrypt_to_device_event(&event).await.unwrap(); - if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { + if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.event.deserialize().unwrap() { let (_, session) = alice_machine - .receive_forwarded_room_key(&sender_key, &mut e) + .receive_forwarded_room_key(&decrypted.sender_key, &mut e) .await .unwrap(); alice_machine @@ -1157,7 +1156,11 @@ mod test { // Check that alice now does have the session. let session = alice_machine .store - .get_inbound_group_session(&room_id(), &sender_key, group_session.session_id()) + .get_inbound_group_session( + &room_id(), + &decrypted.sender_key, + group_session.session_id(), + ) .await .unwrap() .unwrap(); @@ -1325,12 +1328,11 @@ mod test { .unwrap() .is_none()); - let (_, decrypted, sender_key, _) = - alice_account.decrypt_to_device_event(&event).await.unwrap(); + let decrypted = alice_account.decrypt_to_device_event(&event).await.unwrap(); - if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { + if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.event.deserialize().unwrap() { let (_, session) = alice_machine - .receive_forwarded_room_key(&sender_key, &mut e) + .receive_forwarded_room_key(&decrypted.sender_key, &mut e) .await .unwrap(); alice_machine @@ -1345,7 +1347,11 @@ mod test { // Check that alice now does have the session. let session = alice_machine .store - .get_inbound_group_session(&room_id(), &sender_key, group_session.session_id()) + .get_inbound_group_session( + &room_id(), + &decrypted.sender_key, + group_session.session_id(), + ) .await .unwrap() .unwrap(); diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 83b5939a..9252e678 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -52,7 +52,8 @@ use crate::{ key_request::KeyRequestMachine, olm::{ Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, - InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, + InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, ReadOnlyAccount, + SessionType, }, requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, session_manager::{GroupSessionManager, SessionManager}, @@ -365,10 +366,15 @@ impl OlmMachine { /// Mark the cross signing identity as shared. async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> { - self.user_identity.lock().await.mark_as_shared(); - self.store - .save_identity((&*self.user_identity.lock().await).clone()) - .await + let identity = self.user_identity.lock().await; + identity.mark_as_shared(); + + let changes = Changes { + private_identity: Some(identity.clone()), + ..Default::default() + }; + + self.store.save_changes(changes).await } /// Create a new cross signing identity and get the upload request to push @@ -400,11 +406,12 @@ impl OlmMachine { new: vec![public.into()], ..Default::default() }, + private_identity: Some(identity.clone()), ..Default::default() }; self.store.save_changes(changes).await?; - self.store.save_identity(identity.clone()).await?; + Ok((request, signature_request)) } else { info!("Trying to upload the existing cross signing identity"); @@ -555,24 +562,23 @@ impl OlmMachine { async fn decrypt_to_device_event( &self, event: &ToDeviceEvent, - ) -> OlmResult<(Session, Raw, Option)> { - let (session, decrypted_event, sender_key, signing_key) = - self.account.decrypt_to_device_event(event).await?; + ) -> OlmResult { + let mut decrypted = self.account.decrypt_to_device_event(event).await?; // Handle the decrypted event, e.g. fetch out Megolm sessions out of // the event. - if let (Some(event), group_session) = self - .handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event) - .await? + if let (Some(event), group_session) = + self.handle_decrypted_to_device_event(&decrypted).await? { // Some events may have sensitive data e.g. private keys, while we // want to notify our users that a private key was received we // don't want them to be able to do silly things with it. Handling // events modifies them and returns a modified one, so replace it // here if we get one. - Ok((session, event, group_session)) - } else { - Ok((session, decrypted_event, None)) + decrypted.event = event; + decrypted.inbound_group_session = group_session; } + + Ok(decrypted) } /// Create a group session from a room key and add it to our crypto store. @@ -704,27 +710,29 @@ impl OlmMachine { /// * `event` - The decrypted to-device event. async fn handle_decrypted_to_device_event( &self, - sender_key: &str, - signing_key: &str, - event: &Raw, + decrypted: &OlmDecryptionInfo, ) -> OlmResult<(Option>, Option)> { - let event = if let Ok(e) = event.deserialize() { - e - } else { - warn!("Decrypted to-device event failed to be parsed correctly"); - return Ok((None, None)); + let event = match decrypted.event.deserialize() { + Ok(e) => e, + Err(e) => { + warn!( + "Decrypted to-device event failed to be parsed correctly {:?}", + e + ); + return Ok((None, None)); + } }; match event { - AnyToDeviceEvent::RoomKey(mut e) => { - Ok(self.add_room_key(sender_key, signing_key, &mut e).await?) - } + AnyToDeviceEvent::RoomKey(mut e) => Ok(self + .add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e) + .await?), AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self .key_request_machine - .receive_forwarded_room_key(sender_key, &mut e) + .receive_forwarded_room_key(&decrypted.sender_key, &mut e) .await?), _ => { - warn!("Received a unexpected encrypted to-device event"); + warn!("Received an unexpected encrypted to-device event"); Ok((None, None)) } } @@ -808,38 +816,49 @@ impl OlmMachine { match &mut event { AnyToDeviceEvent::RoomEncrypted(e) => { - let (session, decrypted_event, group_session) = - match self.decrypt_to_device_event(e).await { - Ok(e) => e, - Err(err) => { - warn!( - "Failed to decrypt to-device event from {} {}", - e.sender, err - ); + let decrypted = match self.decrypt_to_device_event(e).await { + Ok(e) => e, + Err(err) => { + warn!( + "Failed to decrypt to-device event from {} {}", + e.sender, err + ); - if let OlmError::SessionWedged(sender, curve_key) = err { - if let Err(e) = self - .session_manager - .mark_device_as_wedged(&sender, &curve_key) - .await - { - error!( - "Couldn't mark device from {} to be unwedged {:?}", - sender, e - ); - } + if let OlmError::SessionWedged(sender, curve_key) = err { + if let Err(e) = self + .session_manager + .mark_device_as_wedged(&sender, &curve_key) + .await + { + error!( + "Couldn't mark device from {} to be unwedged {:?}", + sender, e + ); } - continue; } - }; + continue; + } + }; - changes.sessions.push(session); + // New sessions modify the account so we need to save that + // one as well. + match decrypted.session { + SessionType::New(s) => { + changes.sessions.push(s); + changes.account = Some(self.account.inner.clone()); + } + SessionType::Existing(s) => { + changes.sessions.push(s); + } + } - if let Some(group_session) = group_session { + changes.message_hashes.push(decrypted.message_hash); + + if let Some(group_session) = decrypted.inbound_group_session { changes.inbound_group_sessions.push(group_session); } - *event_result = decrypted_event; + *event_result = decrypted.event; } AnyToDeviceEvent::RoomKeyRequest(e) => { self.key_request_machine.receive_incoming_key_request(e) @@ -1283,8 +1302,11 @@ pub(crate) mod test { content, }; - let (session, _, _) = bob.decrypt_to_device_event(&event).await.unwrap(); - bob.store.save_sessions(&[session]).await.unwrap(); + let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); + bob.store + .save_sessions(&[decrypted.session.session()]) + .await + .unwrap(); (alice, bob) } @@ -1578,7 +1600,7 @@ pub(crate) mod test { .decrypt_to_device_event(&event) .await .unwrap() - .1 + .event .deserialize() .unwrap(); @@ -1614,14 +1636,17 @@ pub(crate) mod test { .get_outbound_group_session(&room_id) .unwrap(); - let (session, event, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); + let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); - bob.store.save_sessions(&[session]).await.unwrap(); bob.store - .save_inbound_group_sessions(&[group_session.unwrap()]) + .save_sessions(&[decrypted.session.session()]) .await .unwrap(); - let event = event.deserialize().unwrap(); + bob.store + .save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()]) + .await + .unwrap(); + let event = decrypted.event.deserialize().unwrap(); if let AnyToDeviceEvent::RoomKey(event) = event { assert_eq!(&event.sender, alice.user_id()); @@ -1661,7 +1686,11 @@ pub(crate) mod test { content: to_device_requests_to_content(to_device_requests), }; - let (_, _, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); + let group_session = bob + .decrypt_to_device_event(&event) + .await + .unwrap() + .inbound_group_session; bob.store .save_inbound_group_sessions(&[group_session.unwrap()]) .await diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index ed57b893..0557129a 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -15,6 +15,7 @@ use matrix_sdk_common::events::ToDeviceEvent; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +use sha2::{Digest, Sha256}; use std::{ collections::BTreeMap, convert::{TryFrom, TryInto}, @@ -53,9 +54,10 @@ use olm_rs::{ use crate::{ error::{EventError, OlmResult, SessionCreationError}, + file_encryption::encode, identities::ReadOnlyDevice, requests::UploadSigningKeysRequest, - store::Store, + store::{Changes, Store}, OlmError, }; @@ -70,6 +72,43 @@ pub struct Account { pub(crate) store: Store, } +#[derive(Debug, Clone)] +pub enum SessionType { + New(Session), + Existing(Session), +} + +impl SessionType { + #[cfg(test)] + pub fn session(self) -> Session { + match self { + SessionType::New(s) => s, + SessionType::Existing(s) => s, + } + } +} + +#[derive(Debug, Clone)] +pub struct OlmDecryptionInfo { + pub session: SessionType, + pub message_hash: OlmMessageHash, + pub event: Raw, + pub signing_key: String, + pub sender_key: String, + pub inbound_group_session: Option, +} + +/// A hash of a succesfully decrypted Olm message. +/// +/// Can be used to check if a message has been replayed to us. +#[derive(Debug, Clone)] +pub struct OlmMessageHash { + /// The curve25519 key of the sender that sent us the Olm message. + pub sender_key: String, + /// The hash of the message. + pub hash: String, +} + impl Deref for Account { type Target = ReadOnlyAccount; @@ -82,7 +121,7 @@ impl Account { pub async fn decrypt_to_device_event( &self, event: &ToDeviceEvent, - ) -> OlmResult<(Session, Raw, String, String)> { + ) -> OlmResult { debug!("Decrypting to-device event"); let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content { @@ -103,23 +142,47 @@ impl Account { .try_into() .map_err(|_| EventError::UnsupportedOlmType)?; + let sha = Sha256::new() + .chain(&content.sender_key) + .chain(&[message_type]) + .chain(&ciphertext.body); + + let message_hash = OlmMessageHash { + sender_key: content.sender_key.clone(), + hash: encode(sha.finalize().as_slice()), + }; + // Create a OlmMessage from the ciphertext and the type. let message = OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone()) .map_err(|_| EventError::UnsupportedOlmType)?; // Decrypt the OlmMessage and get a Ruma event out of it. - let (session, decrypted_event, signing_key) = self + let (session, event, signing_key) = match self .decrypt_olm_message(&event.sender, &content.sender_key, message) - .await?; + .await + { + Ok(d) => d, + Err(OlmError::SessionWedged(user_id, sender_key)) => { + if self.store.is_message_known(&message_hash).await? { + return Err(OlmError::ReplayedMessage(user_id, sender_key)); + } else { + return Err(OlmError::SessionWedged(user_id, sender_key)); + } + } + Err(e) => return Err(e.into()), + }; - debug!("Decrypted a to-device event {:?}", decrypted_event); - Ok(( + debug!("Decrypted a to-device event {:?}", event); + + Ok(OlmDecryptionInfo { session, - decrypted_event, - content.sender_key.clone(), + message_hash, + event, signing_key, - )) + sender_key: content.sender_key.clone(), + inbound_group_session: None, + }) } else { warn!("Olm event doesn't contain a ciphertext for our key"); Err(EventError::MissingCiphertext.into()) @@ -227,7 +290,7 @@ impl Account { sender: &UserId, sender_key: &str, message: OlmMessage, - ) -> OlmResult<(Session, Raw, String)> { + ) -> OlmResult<(SessionType, Raw, String)> { // First try to decrypt using an existing session. let (session, plaintext) = if let Some(d) = self .try_decrypt_olm_message(sender, sender_key, &message) @@ -235,7 +298,7 @@ impl Account { { // Decryption succeeded, de-structure the session/plaintext out of // the Option. - d + (SessionType::Existing(d.0), d.1) } else { // Decryption failed with every known session, let's try to create a // new session. @@ -282,7 +345,7 @@ impl Account { // Decrypt our message, this shouldn't fail since we're using a // newly created Session. let plaintext = session.decrypt(message).await?; - (session, plaintext) + (SessionType::New(session), plaintext) }; trace!("Successfully decrypted a Olm message: {}", plaintext); @@ -293,7 +356,20 @@ impl Account { // We might created a new session but decryption might still // have failed, store it for the error case here, this is fine // since we don't expect this to happen often or at all. - self.store.save_sessions(&[session]).await?; + match session { + SessionType::New(s) => { + let changes = Changes { + account: Some(self.inner.clone()), + sessions: vec![s], + ..Default::default() + }; + self.store.save_changes(changes).await?; + } + SessionType::Existing(s) => { + self.store.save_sessions(&[s]).await?; + } + } + return Err(e); } }; diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index 116f9b82..2ec94b23 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -23,8 +23,8 @@ mod session; mod signing; mod utility; -pub(crate) use account::Account; -pub use account::{AccountPickle, PickledAccount, ReadOnlyAccount}; +pub(crate) use account::{Account, OlmDecryptionInfo, SessionType}; +pub use account::{AccountPickle, OlmMessageHash, PickledAccount, ReadOnlyAccount}; pub use group_sessions::{ EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle, PickledInboundGroupSession, diff --git a/matrix_sdk_crypto/src/olm/session.rs b/matrix_sdk_crypto/src/olm/session.rs index a68f8230..8d031f63 100644 --- a/matrix_sdk_crypto/src/olm/session.rs +++ b/matrix_sdk_crypto/src/olm/session.rs @@ -126,9 +126,7 @@ impl Session { "content": content, }); - let plaintext = serde_json::to_string(&payload) - .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload))); - + let plaintext = serde_json::to_string(&payload)?; let ciphertext = self.encrypt_helper(&plaintext).await.to_tuple(); let message_type = ciphertext.0; diff --git a/matrix_sdk_crypto/src/olm/signing/pk_signing.rs b/matrix_sdk_crypto/src/olm/signing/pk_signing.rs index f653ebad..fb552349 100644 --- a/matrix_sdk_crypto/src/olm/signing/pk_signing.rs +++ b/matrix_sdk_crypto/src/olm/signing/pk_signing.rs @@ -170,7 +170,6 @@ impl MasterSigning { } pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) { - // TODO create a borrowed version of a cross singing key. let subkey_wihtout_signatures = json!({ "user_id": subkey.user_id.clone(), "keys": subkey.keys.clone(), diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 952522e5..d9b3403f 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -40,6 +40,7 @@ pub struct MemoryStore { inbound_group_sessions: GroupSessionStore, tracked_users: Arc>, users_for_key_query: Arc>, + olm_hashes: Arc>>, devices: DeviceStore, identities: Arc>, values: Arc>, @@ -52,6 +53,7 @@ impl Default for MemoryStore { inbound_group_sessions: GroupSessionStore::new(), tracked_users: Arc::new(DashSet::new()), users_for_key_query: Arc::new(DashSet::new()), + olm_hashes: Arc::new(DashMap::new()), devices: DeviceStore::new(), identities: Arc::new(DashMap::new()), values: Arc::new(DashMap::new()), @@ -120,6 +122,13 @@ impl CryptoStore for MemoryStore { .insert(identity.user_id().to_owned(), identity.clone()); } + for hash in changes.message_hashes { + self.olm_hashes + .entry(hash.sender_key.to_owned()) + .or_insert_with(DashSet::new) + .insert(hash.hash.clone()); + } + Ok(()) } @@ -211,21 +220,25 @@ impl CryptoStore for MemoryStore { Ok(self.values.get(key).map(|v| v.to_owned())) } - async fn save_identity(&self, _: PrivateCrossSigningIdentity) -> Result<()> { - Ok(()) - } - async fn load_identity(&self) -> Result> { Ok(None) } + + async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { + Ok(self + .olm_hashes + .entry(message_hash.sender_key.to_owned()) + .or_insert_with(DashSet::new) + .contains(&message_hash.hash)) + } } #[cfg(test)] mod test { use crate::{ identities::device::test::get_device, - olm::{test::get_account_and_session, InboundGroupSession}, - store::{memorystore::MemoryStore, CryptoStore}, + olm::{test::get_account_and_session, InboundGroupSession, OlmMessageHash}, + store::{memorystore::MemoryStore, Changes, CryptoStore}, }; use matrix_sdk_common::identifiers::room_id; @@ -329,4 +342,21 @@ mod test { assert!(store.is_user_tracked(device.user_id())); } + + #[tokio::test] + async fn test_message_hash() { + let store = MemoryStore::new(); + + let hash = OlmMessageHash { + sender_key: "test_sender".to_owned(), + hash: "test_hash".to_owned(), + }; + + let mut changes = Changes::default(); + changes.message_hashes.push(hash.clone()); + + assert!(!store.is_message_known(&hash).await.unwrap()); + store.save_changes(changes).await.unwrap(); + assert!(store.is_message_known(&hash).await.unwrap()); + } } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 3fda11ba..9f7f7c4b 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -82,7 +82,9 @@ use matrix_sdk_common_macros::send_sync; use crate::{ error::SessionUnpicklingError, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, - olm::{InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session}, + olm::{ + InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, + }, verification::VerificationMachine, }; @@ -107,7 +109,9 @@ pub(crate) struct Store { #[allow(missing_docs)] pub struct Changes { pub account: Option, + pub private_identity: Option, pub sessions: Vec, + pub message_hashes: Vec, pub inbound_group_sessions: Vec, pub identities: IdentityChanges, pub devices: DeviceChanges, @@ -342,13 +346,14 @@ pub trait CryptoStore: Debug { /// * `account` - The account that should be stored. async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; - /// TODO - async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()>; - - /// TODO + /// Try to load a private cross signing identity, if one is stored. async fn load_identity(&self) -> Result>; - /// TODO + /// Save the set of changes to the store. + /// + /// # Arguments + /// + /// * `changes` - The set of changes that should be stored. async fn save_changes(&self, changes: Changes) -> Result<()>; /// Get all the sessions that belong to the given sender key. @@ -435,4 +440,7 @@ pub trait CryptoStore: Debug { /// Load a serializeable object from the store. async fn get_value(&self, key: &str) -> Result>; + + /// Check if a hash for an Olm message stored in the database. + async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result; } diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index ae01352e..8b843c49 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -42,8 +42,9 @@ use crate::{ identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, olm::{ AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle, - PickledAccount, PickledCrossSigningIdentity, PickledInboundGroupSession, PickledSession, - PicklingMode, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, SessionPickle, + OlmMessageHash, PickledAccount, PickledCrossSigningIdentity, PickledInboundGroupSession, + PickledSession, PicklingMode, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, + SessionPickle, }, }; @@ -491,6 +492,24 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS olm_hashes ( + "id" INTEGER NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "sender_key" TEXT NOT NULL, + "hash" TEXT NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + UNIQUE(account_id,sender_key,hash) + ); + + CREATE INDEX IF NOT EXISTS "olm_hashes_index" ON "olm_hashes" ("account_id"); + "#, + ) + .await?; + Ok(()) } @@ -1466,6 +1485,92 @@ impl SqliteStore { Ok(()) } + async fn save_olm_hashses( + &self, + connection: &mut SqliteConnection, + hashes: &[OlmMessageHash], + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + for hash in hashes { + query("REPLACE INTO olm_hashes (account_id, sender_key, hash) VALUES (?1, ?2, ?3)") + .bind(account_id) + .bind(&hash.sender_key) + .bind(&hash.hash) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + + async fn save_identity( + &self, + connection: &mut SqliteConnection, + identity: PrivateCrossSigningIdentity, + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let pickle = identity.pickle(self.get_pickle_key()).await?; + + query( + "INSERT INTO private_identities ( + account_id, user_id, pickle, shared + ) VALUES (?1, ?2, ?3, ?4) + ON CONFLICT(account_id, user_id) DO UPDATE SET + pickle = excluded.pickle, + shared = excluded.shared + ", + ) + .bind(account_id) + .bind(pickle.user_id.as_str()) + .bind(pickle.pickle) + .bind(pickle.shared) + .execute(&mut *connection) + .await?; + + Ok(()) + } + + async fn save_account_helper( + &self, + connection: &mut SqliteConnection, + account: ReadOnlyAccount, + ) -> Result<()> { + let pickle = account.pickle(self.get_pickle_mode()).await; + + query( + "INSERT INTO accounts ( + user_id, device_id, pickle, shared, uploaded_key_count + ) VALUES (?1, ?2, ?3, ?4, ?5) + ON CONFLICT(user_id, device_id) DO UPDATE SET + pickle = excluded.pickle, + shared = excluded.shared, + uploaded_key_count = excluded.uploaded_key_count + ", + ) + .bind(pickle.user_id.as_str()) + .bind(pickle.device_id.as_str()) + .bind(pickle.pickle.as_str()) + .bind(pickle.shared) + .bind(pickle.uploaded_signed_key_count) + .execute(&mut *connection) + .await?; + + let account_id: (i64,) = + query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?") + .bind(self.user_id.as_str()) + .bind(self.device_id.as_str()) + .fetch_one(&mut *connection) + .await?; + + *self.account_info.lock().unwrap() = Some(AccountInfo { + account_id: account_id.0, + identity_keys: account.identity_keys.clone(), + }); + + Ok(()) + } + async fn save_user_helper( &self, mut connection: &mut SqliteConnection, @@ -1569,65 +1674,7 @@ impl CryptoStore for SqliteStore { async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { let mut connection = self.connection.lock().await; - let pickle = account.pickle(self.get_pickle_mode()).await; - - query( - "INSERT INTO accounts ( - user_id, device_id, pickle, shared, uploaded_key_count - ) VALUES (?1, ?2, ?3, ?4, ?5) - ON CONFLICT(user_id, device_id) DO UPDATE SET - pickle = excluded.pickle, - shared = excluded.shared, - uploaded_key_count = excluded.uploaded_key_count - ", - ) - .bind(pickle.user_id.as_str()) - .bind(pickle.device_id.as_str()) - .bind(pickle.pickle.as_str()) - .bind(pickle.shared) - .bind(pickle.uploaded_signed_key_count) - .execute(&mut *connection) - .await?; - - let account_id: (i64,) = - query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?") - .bind(self.user_id.as_str()) - .bind(self.device_id.as_str()) - .fetch_one(&mut *connection) - .await?; - - *self.account_info.lock().unwrap() = Some(AccountInfo { - account_id: account_id.0, - identity_keys: account.identity_keys.clone(), - }); - - Ok(()) - } - - async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - let pickle = identity.pickle(self.get_pickle_key()).await?; - - let mut connection = self.connection.lock().await; - - query( - "INSERT INTO private_identities ( - account_id, user_id, pickle, shared - ) VALUES (?1, ?2, ?3, ?4) - ON CONFLICT(account_id, user_id) DO UPDATE SET - pickle = excluded.pickle, - shared = excluded.shared - ", - ) - .bind(account_id) - .bind(pickle.user_id.as_str()) - .bind(pickle.pickle) - .bind(pickle.shared) - .execute(&mut *connection) - .await?; - - Ok(()) + self.save_account_helper(&mut connection, account).await } async fn load_identity(&self) -> Result> { @@ -1664,6 +1711,14 @@ impl CryptoStore for SqliteStore { let mut connection = self.connection.lock().await; let mut transaction = connection.begin().await?; + if let Some(account) = changes.account { + self.save_account_helper(&mut transaction, account).await?; + } + + if let Some(identity) = changes.private_identity { + self.save_identity(&mut transaction, identity).await?; + } + self.save_sessions_helper(&mut transaction, &changes.sessions) .await?; self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions) @@ -1680,6 +1735,8 @@ impl CryptoStore for SqliteStore { .await?; self.save_user_identities(&mut transaction, &changes.identities.changed) .await?; + self.save_olm_hashses(&mut transaction, &changes.message_hashes) + .await?; transaction.commit().await?; @@ -1796,6 +1853,22 @@ impl CryptoStore for SqliteStore { Ok(row.map(|r| r.0)) } + + async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let row: Option<(String,)> = query_as( + "SELECT hash FROM olm_hashes WHERE account_id = ? and sender_key = ? and hash = ?", + ) + .bind(account_id) + .bind(&message_hash.sender_key) + .bind(&message_hash.hash) + .fetch_optional(&mut *connection) + .await?; + + Ok(row.is_some()) + } } #[cfg(not(tarpaulin_include))] @@ -1817,8 +1890,8 @@ mod test { user::test::{get_other_identity, get_own_identity}, }, olm::{ - GroupSessionKey, InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, - Session, + GroupSessionKey, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, + ReadOnlyAccount, Session, }, store::{Changes, DeviceChanges, IdentityChanges}, }; @@ -2352,7 +2425,12 @@ mod test { assert!(store.load_identity().await.unwrap().is_none()); let identity = PrivateCrossSigningIdentity::new((&*store.user_id).clone()).await; - store.save_identity(identity.clone()).await.unwrap(); + let changes = Changes { + private_identity: Some(identity.clone()), + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); let loaded_identity = store.load_identity().await.unwrap().unwrap(); assert_eq!(identity.user_id(), loaded_identity.user_id()); } @@ -2371,4 +2449,21 @@ mod test { store.remove_value(&key).await.unwrap(); assert!(store.get_value(&key).await.unwrap().is_none()); } + + #[tokio::test(threaded_scheduler)] + async fn olm_hash_saving() { + let (_, store, _dir) = get_loaded_store().await; + + let hash = OlmMessageHash { + sender_key: "test_sender".to_owned(), + hash: "test_hash".to_owned(), + }; + + let mut changes = Changes::default(); + changes.message_hashes.push(hash.clone()); + + assert!(!store.is_message_known(&hash).await.unwrap()); + store.save_changes(changes).await.unwrap(); + assert!(store.is_message_known(&hash).await.unwrap()); + } } diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index 9c0ba30a..4065d2f7 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -117,9 +117,6 @@ impl Default for AcceptedProtocols { } } -// TODO implement expiration of the verification flow using the timeouts defined -// in the spec. - /// A type level state machine modeling the Sas flow. /// /// This is the generic struc holding common data between the different states