diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index 927d104b..63353cb2 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -46,11 +46,10 @@ base64 = "0.13.0" byteorder = "1.3.4" [target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx] -git = "https://github.com/launchbadge/sqlx/" -rev = "fd25a7530cf087e1529553ff854f192738db3461" +version = "0.4.1" optional = true default-features = false -features = ["runtime-tokio", "sqlite", "macros"] +features = ["runtime-tokio-native-tls", "sqlite", "macros"] [dev-dependencies] tokio = { version = "0.2.22", default-features = false, features = ["rt-threaded", "macros"] } diff --git a/matrix_sdk_crypto/src/error.rs b/matrix_sdk_crypto/src/error.rs index 6d9aacdc..24d68f1c 100644 --- a/matrix_sdk_crypto/src/error.rs +++ b/matrix_sdk_crypto/src/error.rs @@ -47,9 +47,16 @@ pub enum OlmError { Store(#[from] CryptoStoreError), /// The session with a device has become corrupted. - #[error("decryption failed likely because an Olm from {0} with sender key {1} was wedged")] + #[error( + "decryption failed likely because an Olm session from {0} with sender key {1} was wedged" + )] SessionWedged(UserId, String), + /// An Olm message got replayed while the Olm ratchet has already moved + /// forward. + #[error("decryption failed because an Olm message from {0} with sender key {1} was replayed")] + ReplayedMessage(UserId, String), + /// Encryption failed because the device does not have a valid Olm session /// with us. #[error( diff --git a/matrix_sdk_crypto/src/file_encryption/mod.rs b/matrix_sdk_crypto/src/file_encryption/mod.rs index 3c644bd9..45ccd310 100644 --- a/matrix_sdk_crypto/src/file_encryption/mod.rs +++ b/matrix_sdk_crypto/src/file_encryption/mod.rs @@ -14,10 +14,10 @@ fn decode_url_safe(input: impl AsRef<[u8]>) -> Result, DecodeError> { decode_config(input, URL_SAFE_NO_PAD) } -fn encode(input: impl AsRef<[u8]>) -> String { +pub fn encode(input: impl AsRef<[u8]>) -> String { encode_config(input, STANDARD_NO_PAD) } -fn encode_url_safe(input: impl AsRef<[u8]>) -> String { +pub fn encode_url_safe(input: impl AsRef<[u8]>) -> String { encode_config(input, URL_SAFE_NO_PAD) } diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 55f5b569..32cbd69f 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -1137,12 +1137,11 @@ mod test { .unwrap() .is_none()); - let (_, decrypted, sender_key, _) = - alice_account.decrypt_to_device_event(&event).await.unwrap(); + let decrypted = alice_account.decrypt_to_device_event(&event).await.unwrap(); - if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { + if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.event.deserialize().unwrap() { let (_, session) = alice_machine - .receive_forwarded_room_key(&sender_key, &mut e) + .receive_forwarded_room_key(&decrypted.sender_key, &mut e) .await .unwrap(); alice_machine @@ -1157,7 +1156,11 @@ mod test { // Check that alice now does have the session. let session = alice_machine .store - .get_inbound_group_session(&room_id(), &sender_key, group_session.session_id()) + .get_inbound_group_session( + &room_id(), + &decrypted.sender_key, + group_session.session_id(), + ) .await .unwrap() .unwrap(); @@ -1325,12 +1328,11 @@ mod test { .unwrap() .is_none()); - let (_, decrypted, sender_key, _) = - alice_account.decrypt_to_device_event(&event).await.unwrap(); + let decrypted = alice_account.decrypt_to_device_event(&event).await.unwrap(); - if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { + if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.event.deserialize().unwrap() { let (_, session) = alice_machine - .receive_forwarded_room_key(&sender_key, &mut e) + .receive_forwarded_room_key(&decrypted.sender_key, &mut e) .await .unwrap(); alice_machine @@ -1345,7 +1347,11 @@ mod test { // Check that alice now does have the session. let session = alice_machine .store - .get_inbound_group_session(&room_id(), &sender_key, group_session.session_id()) + .get_inbound_group_session( + &room_id(), + &decrypted.sender_key, + group_session.session_id(), + ) .await .unwrap() .unwrap(); diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 83b5939a..fdb93aee 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -52,7 +52,7 @@ use crate::{ key_request::KeyRequestMachine, olm::{ Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, - InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, + InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, ReadOnlyAccount, }, requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, session_manager::{GroupSessionManager, SessionManager}, @@ -555,24 +555,23 @@ impl OlmMachine { async fn decrypt_to_device_event( &self, event: &ToDeviceEvent, - ) -> OlmResult<(Session, Raw, Option)> { - let (session, decrypted_event, sender_key, signing_key) = - self.account.decrypt_to_device_event(event).await?; + ) -> OlmResult { + let mut decrypted = self.account.decrypt_to_device_event(event).await?; // Handle the decrypted event, e.g. fetch out Megolm sessions out of // the event. - if let (Some(event), group_session) = self - .handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event) - .await? + if let (Some(event), group_session) = + self.handle_decrypted_to_device_event(&decrypted).await? { // Some events may have sensitive data e.g. private keys, while we // want to notify our users that a private key was received we // don't want them to be able to do silly things with it. Handling // events modifies them and returns a modified one, so replace it // here if we get one. - Ok((session, event, group_session)) - } else { - Ok((session, decrypted_event, None)) + decrypted.event = event; + decrypted.inbound_group_session = group_session; } + + Ok(decrypted) } /// Create a group session from a room key and add it to our crypto store. @@ -704,27 +703,29 @@ impl OlmMachine { /// * `event` - The decrypted to-device event. async fn handle_decrypted_to_device_event( &self, - sender_key: &str, - signing_key: &str, - event: &Raw, + decrypted: &OlmDecryptionInfo, ) -> OlmResult<(Option>, Option)> { - let event = if let Ok(e) = event.deserialize() { - e - } else { - warn!("Decrypted to-device event failed to be parsed correctly"); - return Ok((None, None)); + let event = match decrypted.event.deserialize() { + Ok(e) => e, + Err(e) => { + warn!( + "Decrypted to-device event failed to be parsed correctly {:?}", + e + ); + return Ok((None, None)); + } }; match event { - AnyToDeviceEvent::RoomKey(mut e) => { - Ok(self.add_room_key(sender_key, signing_key, &mut e).await?) - } + AnyToDeviceEvent::RoomKey(mut e) => Ok(self + .add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e) + .await?), AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self .key_request_machine - .receive_forwarded_room_key(sender_key, &mut e) + .receive_forwarded_room_key(&decrypted.sender_key, &mut e) .await?), _ => { - warn!("Received a unexpected encrypted to-device event"); + warn!("Received an unexpected encrypted to-device event"); Ok((None, None)) } } @@ -808,38 +809,38 @@ impl OlmMachine { match &mut event { AnyToDeviceEvent::RoomEncrypted(e) => { - let (session, decrypted_event, group_session) = - match self.decrypt_to_device_event(e).await { - Ok(e) => e, - Err(err) => { - warn!( - "Failed to decrypt to-device event from {} {}", - e.sender, err - ); + let decrypted = match self.decrypt_to_device_event(e).await { + Ok(e) => e, + Err(err) => { + warn!( + "Failed to decrypt to-device event from {} {}", + e.sender, err + ); - if let OlmError::SessionWedged(sender, curve_key) = err { - if let Err(e) = self - .session_manager - .mark_device_as_wedged(&sender, &curve_key) - .await - { - error!( - "Couldn't mark device from {} to be unwedged {:?}", - sender, e - ); - } + if let OlmError::SessionWedged(sender, curve_key) = err { + if let Err(e) = self + .session_manager + .mark_device_as_wedged(&sender, &curve_key) + .await + { + error!( + "Couldn't mark device from {} to be unwedged {:?}", + sender, e + ); } - continue; } - }; + continue; + } + }; - changes.sessions.push(session); + changes.sessions.push(decrypted.session); + changes.message_hashes.push(decrypted.message_hash); - if let Some(group_session) = group_session { + if let Some(group_session) = decrypted.inbound_group_session { changes.inbound_group_sessions.push(group_session); } - *event_result = decrypted_event; + *event_result = decrypted.event; } AnyToDeviceEvent::RoomKeyRequest(e) => { self.key_request_machine.receive_incoming_key_request(e) @@ -1283,8 +1284,8 @@ pub(crate) mod test { content, }; - let (session, _, _) = bob.decrypt_to_device_event(&event).await.unwrap(); - bob.store.save_sessions(&[session]).await.unwrap(); + let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); + bob.store.save_sessions(&[decrypted.session]).await.unwrap(); (alice, bob) } @@ -1578,7 +1579,7 @@ pub(crate) mod test { .decrypt_to_device_event(&event) .await .unwrap() - .1 + .event .deserialize() .unwrap(); @@ -1614,14 +1615,14 @@ pub(crate) mod test { .get_outbound_group_session(&room_id) .unwrap(); - let (session, event, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); + let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); - bob.store.save_sessions(&[session]).await.unwrap(); + bob.store.save_sessions(&[decrypted.session]).await.unwrap(); bob.store - .save_inbound_group_sessions(&[group_session.unwrap()]) + .save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()]) .await .unwrap(); - let event = event.deserialize().unwrap(); + let event = decrypted.event.deserialize().unwrap(); if let AnyToDeviceEvent::RoomKey(event) = event { assert_eq!(&event.sender, alice.user_id()); @@ -1661,7 +1662,11 @@ pub(crate) mod test { content: to_device_requests_to_content(to_device_requests), }; - let (_, _, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); + let group_session = bob + .decrypt_to_device_event(&event) + .await + .unwrap() + .inbound_group_session; bob.store .save_inbound_group_sessions(&[group_session.unwrap()]) .await diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index ed57b893..1861c254 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -15,6 +15,7 @@ use matrix_sdk_common::events::ToDeviceEvent; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +use sha2::{Digest, Sha256}; use std::{ collections::BTreeMap, convert::{TryFrom, TryInto}, @@ -53,6 +54,7 @@ use olm_rs::{ use crate::{ error::{EventError, OlmResult, SessionCreationError}, + file_encryption::encode, identities::ReadOnlyDevice, requests::UploadSigningKeysRequest, store::Store, @@ -70,6 +72,27 @@ pub struct Account { pub(crate) store: Store, } +#[derive(Debug, Clone)] +pub struct OlmDecryptionInfo { + pub session: Session, + pub message_hash: OlmMessageHash, + pub event: Raw, + pub signing_key: String, + pub sender_key: String, + pub inbound_group_session: Option, +} + +/// A hash of a succesfully decrypted Olm message. +/// +/// Can be used to check if a message has been replayed to us. +#[derive(Debug, Clone)] +pub struct OlmMessageHash { + /// The curve25519 key of the sender that sent us the Olm message. + pub sender_key: String, + /// The hash of the message. + pub hash: String, +} + impl Deref for Account { type Target = ReadOnlyAccount; @@ -82,7 +105,7 @@ impl Account { pub async fn decrypt_to_device_event( &self, event: &ToDeviceEvent, - ) -> OlmResult<(Session, Raw, String, String)> { + ) -> OlmResult { debug!("Decrypting to-device event"); let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content { @@ -103,23 +126,47 @@ impl Account { .try_into() .map_err(|_| EventError::UnsupportedOlmType)?; + let sha = Sha256::new() + .chain(&content.sender_key) + .chain(&[message_type]) + .chain(&ciphertext.body); + + let message_hash = OlmMessageHash { + sender_key: content.sender_key.clone(), + hash: encode(sha.finalize().as_slice()), + }; + // Create a OlmMessage from the ciphertext and the type. let message = OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone()) .map_err(|_| EventError::UnsupportedOlmType)?; // Decrypt the OlmMessage and get a Ruma event out of it. - let (session, decrypted_event, signing_key) = self + let (session, event, signing_key) = match self .decrypt_olm_message(&event.sender, &content.sender_key, message) - .await?; + .await + { + Ok(d) => d, + Err(OlmError::SessionWedged(user_id, sender_key)) => { + if self.store.is_message_known(&message_hash).await? { + return Err(OlmError::ReplayedMessage(user_id, sender_key)); + } else { + return Err(OlmError::SessionWedged(user_id, sender_key)); + } + } + Err(e) => return Err(e.into()), + }; - debug!("Decrypted a to-device event {:?}", decrypted_event); - Ok(( + debug!("Decrypted a to-device event {:?}", event); + + Ok(OlmDecryptionInfo { session, - decrypted_event, - content.sender_key.clone(), + message_hash, + event, signing_key, - )) + sender_key: content.sender_key.clone(), + inbound_group_session: None, + }) } else { warn!("Olm event doesn't contain a ciphertext for our key"); Err(EventError::MissingCiphertext.into()) diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index 116f9b82..873d25b8 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -23,8 +23,8 @@ mod session; mod signing; mod utility; -pub(crate) use account::Account; -pub use account::{AccountPickle, PickledAccount, ReadOnlyAccount}; +pub(crate) use account::{Account, OlmDecryptionInfo}; +pub use account::{AccountPickle, OlmMessageHash, PickledAccount, ReadOnlyAccount}; pub use group_sessions::{ EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle, PickledInboundGroupSession, diff --git a/matrix_sdk_crypto/src/olm/session.rs b/matrix_sdk_crypto/src/olm/session.rs index a68f8230..8d031f63 100644 --- a/matrix_sdk_crypto/src/olm/session.rs +++ b/matrix_sdk_crypto/src/olm/session.rs @@ -126,9 +126,7 @@ impl Session { "content": content, }); - let plaintext = serde_json::to_string(&payload) - .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload))); - + let plaintext = serde_json::to_string(&payload)?; let ciphertext = self.encrypt_helper(&plaintext).await.to_tuple(); let message_type = ciphertext.0; diff --git a/matrix_sdk_crypto/src/olm/signing/pk_signing.rs b/matrix_sdk_crypto/src/olm/signing/pk_signing.rs index f653ebad..fb552349 100644 --- a/matrix_sdk_crypto/src/olm/signing/pk_signing.rs +++ b/matrix_sdk_crypto/src/olm/signing/pk_signing.rs @@ -170,7 +170,6 @@ impl MasterSigning { } pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) { - // TODO create a borrowed version of a cross singing key. let subkey_wihtout_signatures = json!({ "user_id": subkey.user_id.clone(), "keys": subkey.keys.clone(), diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 952522e5..b8edd074 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -40,6 +40,7 @@ pub struct MemoryStore { inbound_group_sessions: GroupSessionStore, tracked_users: Arc>, users_for_key_query: Arc>, + olm_hashes: Arc>>, devices: DeviceStore, identities: Arc>, values: Arc>, @@ -52,6 +53,7 @@ impl Default for MemoryStore { inbound_group_sessions: GroupSessionStore::new(), tracked_users: Arc::new(DashSet::new()), users_for_key_query: Arc::new(DashSet::new()), + olm_hashes: Arc::new(DashMap::new()), devices: DeviceStore::new(), identities: Arc::new(DashMap::new()), values: Arc::new(DashMap::new()), @@ -120,6 +122,13 @@ impl CryptoStore for MemoryStore { .insert(identity.user_id().to_owned(), identity.clone()); } + for hash in changes.message_hashes { + self.olm_hashes + .entry(hash.sender_key.to_owned()) + .or_insert_with(DashSet::new) + .insert(hash.hash.clone()); + } + Ok(()) } @@ -218,14 +227,22 @@ impl CryptoStore for MemoryStore { async fn load_identity(&self) -> Result> { Ok(None) } + + async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { + Ok(self + .olm_hashes + .entry(message_hash.sender_key.to_owned()) + .or_insert_with(DashSet::new) + .contains(&message_hash.hash)) + } } #[cfg(test)] mod test { use crate::{ identities::device::test::get_device, - olm::{test::get_account_and_session, InboundGroupSession}, - store::{memorystore::MemoryStore, CryptoStore}, + olm::{test::get_account_and_session, InboundGroupSession, OlmMessageHash}, + store::{memorystore::MemoryStore, Changes, CryptoStore}, }; use matrix_sdk_common::identifiers::room_id; @@ -329,4 +346,21 @@ mod test { assert!(store.is_user_tracked(device.user_id())); } + + #[tokio::test] + async fn test_message_hash() { + let store = MemoryStore::new(); + + let hash = OlmMessageHash { + sender_key: "test_sender".to_owned(), + hash: "test_hash".to_owned(), + }; + + let mut changes = Changes::default(); + changes.message_hashes.push(hash.clone()); + + assert!(!store.is_message_known(&hash).await.unwrap()); + store.save_changes(changes).await.unwrap(); + assert!(store.is_message_known(&hash).await.unwrap()); + } } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 3fda11ba..76ad72fb 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -82,7 +82,9 @@ use matrix_sdk_common_macros::send_sync; use crate::{ error::SessionUnpicklingError, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, - olm::{InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session}, + olm::{ + InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, + }, verification::VerificationMachine, }; @@ -108,6 +110,7 @@ pub(crate) struct Store { pub struct Changes { pub account: Option, pub sessions: Vec, + pub message_hashes: Vec, pub inbound_group_sessions: Vec, pub identities: IdentityChanges, pub devices: DeviceChanges, @@ -342,13 +345,22 @@ pub trait CryptoStore: Debug { /// * `account` - The account that should be stored. async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; - /// TODO + /// Save the given privat identity in the store. + /// + /// # Arguments + /// + /// * `identity` - The private cross signing identity that should be saved + /// in the store. async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()>; - /// TODO + /// Try to load a private cross signing identity, if one is stored. async fn load_identity(&self) -> Result>; - /// TODO + /// Save the set of changes to the store. + /// + /// # Arguments + /// + /// * `changes` - The set of changes that should be stored. async fn save_changes(&self, changes: Changes) -> Result<()>; /// Get all the sessions that belong to the given sender key. @@ -435,4 +447,7 @@ pub trait CryptoStore: Debug { /// Load a serializeable object from the store. async fn get_value(&self, key: &str) -> Result>; + + /// Check if a hash for an Olm message stored in the database. + async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result; } diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index ae01352e..d41719fb 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -42,8 +42,9 @@ use crate::{ identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, olm::{ AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle, - PickledAccount, PickledCrossSigningIdentity, PickledInboundGroupSession, PickledSession, - PicklingMode, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, SessionPickle, + OlmMessageHash, PickledAccount, PickledCrossSigningIdentity, PickledInboundGroupSession, + PickledSession, PicklingMode, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, + SessionPickle, }, }; @@ -491,6 +492,24 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS olm_hashes ( + "id" INTEGER NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "sender_key" TEXT NOT NULL, + "hash" TEXT NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + UNIQUE(account_id,sender_key,hash) + ); + + CREATE INDEX IF NOT EXISTS "olm_hashes_index" ON "olm_hashes" ("account_id"); + "#, + ) + .await?; + Ok(()) } @@ -1466,6 +1485,25 @@ impl SqliteStore { Ok(()) } + async fn save_olm_hashses( + &self, + connection: &mut SqliteConnection, + hashes: &[OlmMessageHash], + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + for hash in hashes { + query("REPLACE INTO olm_hashes (account_id, sender_key, hash) VALUES (?1, ?2, ?3)") + .bind(account_id) + .bind(&hash.sender_key) + .bind(&hash.hash) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + async fn save_user_helper( &self, mut connection: &mut SqliteConnection, @@ -1681,6 +1719,9 @@ impl CryptoStore for SqliteStore { self.save_user_identities(&mut transaction, &changes.identities.changed) .await?; + self.save_olm_hashses(&mut transaction, &changes.message_hashes) + .await?; + transaction.commit().await?; Ok(()) @@ -1796,6 +1837,22 @@ impl CryptoStore for SqliteStore { Ok(row.map(|r| r.0)) } + + async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let row: Option<(String,)> = query_as( + "SELECT hash FROM olm_hashes WHERE account_id = ? and sender_key = ? and hash = ?", + ) + .bind(account_id) + .bind(&message_hash.sender_key) + .bind(&message_hash.hash) + .fetch_optional(&mut *connection) + .await?; + + Ok(row.is_some()) + } } #[cfg(not(tarpaulin_include))] @@ -1817,8 +1874,8 @@ mod test { user::test::{get_other_identity, get_own_identity}, }, olm::{ - GroupSessionKey, InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, - Session, + GroupSessionKey, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, + ReadOnlyAccount, Session, }, store::{Changes, DeviceChanges, IdentityChanges}, }; @@ -2371,4 +2428,21 @@ mod test { store.remove_value(&key).await.unwrap(); assert!(store.get_value(&key).await.unwrap().is_none()); } + + #[tokio::test(threaded_scheduler)] + async fn olm_hash_saving() { + let (_, store, _dir) = get_loaded_store().await; + + let hash = OlmMessageHash { + sender_key: "test_sender".to_owned(), + hash: "test_hash".to_owned(), + }; + + let mut changes = Changes::default(); + changes.message_hashes.push(hash.clone()); + + assert!(!store.is_message_known(&hash).await.unwrap()); + store.save_changes(changes).await.unwrap(); + assert!(store.is_message_known(&hash).await.unwrap()); + } } diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index 9c0ba30a..4065d2f7 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -117,9 +117,6 @@ impl Default for AcceptedProtocols { } } -// TODO implement expiration of the verification flow using the timeouts defined -// in the spec. - /// A type level state machine modeling the Sas flow. /// /// This is the generic struc holding common data between the different states