diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index fdb93aee..9252e678 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -53,6 +53,7 @@ use crate::{ olm::{ Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, ReadOnlyAccount, + SessionType, }, requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, session_manager::{GroupSessionManager, SessionManager}, @@ -365,10 +366,15 @@ impl OlmMachine { /// Mark the cross signing identity as shared. async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> { - self.user_identity.lock().await.mark_as_shared(); - self.store - .save_identity((&*self.user_identity.lock().await).clone()) - .await + let identity = self.user_identity.lock().await; + identity.mark_as_shared(); + + let changes = Changes { + private_identity: Some(identity.clone()), + ..Default::default() + }; + + self.store.save_changes(changes).await } /// Create a new cross signing identity and get the upload request to push @@ -400,11 +406,12 @@ impl OlmMachine { new: vec![public.into()], ..Default::default() }, + private_identity: Some(identity.clone()), ..Default::default() }; self.store.save_changes(changes).await?; - self.store.save_identity(identity.clone()).await?; + Ok((request, signature_request)) } else { info!("Trying to upload the existing cross signing identity"); @@ -833,7 +840,18 @@ impl OlmMachine { } }; - changes.sessions.push(decrypted.session); + // New sessions modify the account so we need to save that + // one as well. + match decrypted.session { + SessionType::New(s) => { + changes.sessions.push(s); + changes.account = Some(self.account.inner.clone()); + } + SessionType::Existing(s) => { + changes.sessions.push(s); + } + } + changes.message_hashes.push(decrypted.message_hash); if let Some(group_session) = decrypted.inbound_group_session { @@ -1285,7 +1303,10 @@ pub(crate) mod test { }; let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); - bob.store.save_sessions(&[decrypted.session]).await.unwrap(); + bob.store + .save_sessions(&[decrypted.session.session()]) + .await + .unwrap(); (alice, bob) } @@ -1617,7 +1638,10 @@ pub(crate) mod test { let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); - bob.store.save_sessions(&[decrypted.session]).await.unwrap(); + bob.store + .save_sessions(&[decrypted.session.session()]) + .await + .unwrap(); bob.store .save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()]) .await diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index 1861c254..0557129a 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -57,7 +57,7 @@ use crate::{ file_encryption::encode, identities::ReadOnlyDevice, requests::UploadSigningKeysRequest, - store::Store, + store::{Changes, Store}, OlmError, }; @@ -72,9 +72,25 @@ pub struct Account { pub(crate) store: Store, } +#[derive(Debug, Clone)] +pub enum SessionType { + New(Session), + Existing(Session), +} + +impl SessionType { + #[cfg(test)] + pub fn session(self) -> Session { + match self { + SessionType::New(s) => s, + SessionType::Existing(s) => s, + } + } +} + #[derive(Debug, Clone)] pub struct OlmDecryptionInfo { - pub session: Session, + pub session: SessionType, pub message_hash: OlmMessageHash, pub event: Raw, pub signing_key: String, @@ -274,7 +290,7 @@ impl Account { sender: &UserId, sender_key: &str, message: OlmMessage, - ) -> OlmResult<(Session, Raw, String)> { + ) -> OlmResult<(SessionType, Raw, String)> { // First try to decrypt using an existing session. let (session, plaintext) = if let Some(d) = self .try_decrypt_olm_message(sender, sender_key, &message) @@ -282,7 +298,7 @@ impl Account { { // Decryption succeeded, de-structure the session/plaintext out of // the Option. - d + (SessionType::Existing(d.0), d.1) } else { // Decryption failed with every known session, let's try to create a // new session. @@ -329,7 +345,7 @@ impl Account { // Decrypt our message, this shouldn't fail since we're using a // newly created Session. let plaintext = session.decrypt(message).await?; - (session, plaintext) + (SessionType::New(session), plaintext) }; trace!("Successfully decrypted a Olm message: {}", plaintext); @@ -340,7 +356,20 @@ impl Account { // We might created a new session but decryption might still // have failed, store it for the error case here, this is fine // since we don't expect this to happen often or at all. - self.store.save_sessions(&[session]).await?; + match session { + SessionType::New(s) => { + let changes = Changes { + account: Some(self.inner.clone()), + sessions: vec![s], + ..Default::default() + }; + self.store.save_changes(changes).await?; + } + SessionType::Existing(s) => { + self.store.save_sessions(&[s]).await?; + } + } + return Err(e); } }; diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index 873d25b8..2ec94b23 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -23,7 +23,7 @@ mod session; mod signing; mod utility; -pub(crate) use account::{Account, OlmDecryptionInfo}; +pub(crate) use account::{Account, OlmDecryptionInfo, SessionType}; pub use account::{AccountPickle, OlmMessageHash, PickledAccount, ReadOnlyAccount}; pub use group_sessions::{ EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle, diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index b8edd074..d9b3403f 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -220,10 +220,6 @@ impl CryptoStore for MemoryStore { Ok(self.values.get(key).map(|v| v.to_owned())) } - async fn save_identity(&self, _: PrivateCrossSigningIdentity) -> Result<()> { - Ok(()) - } - async fn load_identity(&self) -> Result> { Ok(None) } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 76ad72fb..9f7f7c4b 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -109,6 +109,7 @@ pub(crate) struct Store { #[allow(missing_docs)] pub struct Changes { pub account: Option, + pub private_identity: Option, pub sessions: Vec, pub message_hashes: Vec, pub inbound_group_sessions: Vec, @@ -345,14 +346,6 @@ pub trait CryptoStore: Debug { /// * `account` - The account that should be stored. async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; - /// Save the given privat identity in the store. - /// - /// # Arguments - /// - /// * `identity` - The private cross signing identity that should be saved - /// in the store. - async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()>; - /// Try to load a private cross signing identity, if one is stored. async fn load_identity(&self) -> Result>; diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index d41719fb..8b843c49 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -1504,6 +1504,73 @@ impl SqliteStore { Ok(()) } + async fn save_identity( + &self, + connection: &mut SqliteConnection, + identity: PrivateCrossSigningIdentity, + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let pickle = identity.pickle(self.get_pickle_key()).await?; + + query( + "INSERT INTO private_identities ( + account_id, user_id, pickle, shared + ) VALUES (?1, ?2, ?3, ?4) + ON CONFLICT(account_id, user_id) DO UPDATE SET + pickle = excluded.pickle, + shared = excluded.shared + ", + ) + .bind(account_id) + .bind(pickle.user_id.as_str()) + .bind(pickle.pickle) + .bind(pickle.shared) + .execute(&mut *connection) + .await?; + + Ok(()) + } + + async fn save_account_helper( + &self, + connection: &mut SqliteConnection, + account: ReadOnlyAccount, + ) -> Result<()> { + let pickle = account.pickle(self.get_pickle_mode()).await; + + query( + "INSERT INTO accounts ( + user_id, device_id, pickle, shared, uploaded_key_count + ) VALUES (?1, ?2, ?3, ?4, ?5) + ON CONFLICT(user_id, device_id) DO UPDATE SET + pickle = excluded.pickle, + shared = excluded.shared, + uploaded_key_count = excluded.uploaded_key_count + ", + ) + .bind(pickle.user_id.as_str()) + .bind(pickle.device_id.as_str()) + .bind(pickle.pickle.as_str()) + .bind(pickle.shared) + .bind(pickle.uploaded_signed_key_count) + .execute(&mut *connection) + .await?; + + let account_id: (i64,) = + query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?") + .bind(self.user_id.as_str()) + .bind(self.device_id.as_str()) + .fetch_one(&mut *connection) + .await?; + + *self.account_info.lock().unwrap() = Some(AccountInfo { + account_id: account_id.0, + identity_keys: account.identity_keys.clone(), + }); + + Ok(()) + } + async fn save_user_helper( &self, mut connection: &mut SqliteConnection, @@ -1607,65 +1674,7 @@ impl CryptoStore for SqliteStore { async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { let mut connection = self.connection.lock().await; - let pickle = account.pickle(self.get_pickle_mode()).await; - - query( - "INSERT INTO accounts ( - user_id, device_id, pickle, shared, uploaded_key_count - ) VALUES (?1, ?2, ?3, ?4, ?5) - ON CONFLICT(user_id, device_id) DO UPDATE SET - pickle = excluded.pickle, - shared = excluded.shared, - uploaded_key_count = excluded.uploaded_key_count - ", - ) - .bind(pickle.user_id.as_str()) - .bind(pickle.device_id.as_str()) - .bind(pickle.pickle.as_str()) - .bind(pickle.shared) - .bind(pickle.uploaded_signed_key_count) - .execute(&mut *connection) - .await?; - - let account_id: (i64,) = - query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?") - .bind(self.user_id.as_str()) - .bind(self.device_id.as_str()) - .fetch_one(&mut *connection) - .await?; - - *self.account_info.lock().unwrap() = Some(AccountInfo { - account_id: account_id.0, - identity_keys: account.identity_keys.clone(), - }); - - Ok(()) - } - - async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - let pickle = identity.pickle(self.get_pickle_key()).await?; - - let mut connection = self.connection.lock().await; - - query( - "INSERT INTO private_identities ( - account_id, user_id, pickle, shared - ) VALUES (?1, ?2, ?3, ?4) - ON CONFLICT(account_id, user_id) DO UPDATE SET - pickle = excluded.pickle, - shared = excluded.shared - ", - ) - .bind(account_id) - .bind(pickle.user_id.as_str()) - .bind(pickle.pickle) - .bind(pickle.shared) - .execute(&mut *connection) - .await?; - - Ok(()) + self.save_account_helper(&mut connection, account).await } async fn load_identity(&self) -> Result> { @@ -1702,6 +1711,14 @@ impl CryptoStore for SqliteStore { let mut connection = self.connection.lock().await; let mut transaction = connection.begin().await?; + if let Some(account) = changes.account { + self.save_account_helper(&mut transaction, account).await?; + } + + if let Some(identity) = changes.private_identity { + self.save_identity(&mut transaction, identity).await?; + } + self.save_sessions_helper(&mut transaction, &changes.sessions) .await?; self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions) @@ -1718,7 +1735,6 @@ impl CryptoStore for SqliteStore { .await?; self.save_user_identities(&mut transaction, &changes.identities.changed) .await?; - self.save_olm_hashses(&mut transaction, &changes.message_hashes) .await?; @@ -2409,7 +2425,12 @@ mod test { assert!(store.load_identity().await.unwrap().is_none()); let identity = PrivateCrossSigningIdentity::new((&*store.user_id).clone()).await; - store.save_identity(identity.clone()).await.unwrap(); + let changes = Changes { + private_identity: Some(identity.clone()), + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); let loaded_identity = store.load_identity().await.unwrap().unwrap(); assert_eq!(identity.user_id(), loaded_identity.user_id()); }