diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index b14567aa..56c9c44c 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -358,7 +358,7 @@ impl OlmMachine { } }; - if let Err(e) = self.store.save_session(session).await { + if let Err(e) = self.store.save_sessions(&[session]).await { error!("Failed to store newly created Olm session {}", e); continue; } @@ -739,7 +739,7 @@ impl OlmMachine { // Decryption was successful, save the new ratchet state of the // session that was used to decrypt the message. trace!("Saved the new session state for {}", sender); - self.store.save_session(session).await?; + self.store.save_sessions(&[session]).await?; } Ok(plaintext) @@ -804,7 +804,7 @@ impl OlmMachine { let plaintext = session.decrypt(message).await?; // Save the new ratcheted state of the session. - self.store.save_session(session).await?; + self.store.save_sessions(&[session]).await?; plaintext }; @@ -1055,7 +1055,7 @@ impl OlmMachine { .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload))); let ciphertext = session.encrypt(&plaintext).await.to_tuple(); - self.store.save_session(session).await?; + self.store.save_sessions(&[session]).await?; let message_type: usize = ciphertext.0.into(); diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index c95145f7..0dcc0451 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -52,8 +52,11 @@ impl CryptoStore for MemoryStore { Ok(()) } - async fn save_session(&mut self, session: Session) -> Result<()> { - self.sessions.add(session).await; + async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> { + for session in sessions { + self.sessions.add(session.clone()).await; + } + Ok(()) } @@ -125,7 +128,7 @@ mod test { assert!(store.load_account().await.unwrap().is_none()); store.save_account(account).await.unwrap(); - store.save_session(session.clone()).await.unwrap(); + store.save_sessions(&[session.clone()]).await.unwrap(); let sessions = store .get_sessions(&session.sender_key) diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index d494418a..72ab61ed 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -75,12 +75,12 @@ pub trait CryptoStore: Debug + Send + Sync { /// * `account` - The account that should be stored. async fn save_account(&mut self, account: Account) -> Result<()>; - /// Save the given session in the store. + /// Save the given sessions in the store. /// /// # Arguments /// - /// * `session` - The session that should be stored. - async fn save_session(&mut self, session: Session) -> Result<()>; + /// * `session` - The sessions that should be stored. + async fn save_sessions(&mut self, session: &[Session]) -> Result<()>; /// Get all the sessions that belong to the given sender key. /// diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 9be64924..b806a8e1 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -527,32 +527,35 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn save_session(&mut self, session: Session) -> Result<()> { - self.lazy_load_sessions(&session.sender_key).await?; - self.sessions.add(session.clone()).await; + async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> { + // TODO turn this into a transaction + for session in sessions { + self.lazy_load_sessions(&session.sender_key).await?; + self.sessions.add(session.clone()).await; - let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; - let session_id = session.session_id(); - let creation_time = serde_json::to_string(&session.creation_time.elapsed())?; - let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?; - let pickle = session.pickle(self.get_pickle_mode()).await; + let session_id = session.session_id(); + let creation_time = serde_json::to_string(&session.creation_time.elapsed())?; + let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?; + let pickle = session.pickle(self.get_pickle_mode()).await; - let mut connection = self.connection.lock().await; + let mut connection = self.connection.lock().await; - query( - "REPLACE INTO sessions ( - session_id, account_id, creation_time, last_use_time, sender_key, pickle - ) VALUES (?, ?, ?, ?, ?, ?)", - ) - .bind(&session_id) - .bind(&account_id) - .bind(&*creation_time) - .bind(&*last_use_time) - .bind(&*session.sender_key) - .bind(&pickle) - .execute(&mut *connection) - .await?; + query( + "REPLACE INTO sessions ( + session_id, account_id, creation_time, last_use_time, sender_key, pickle + ) VALUES (?, ?, ?, ?, ?, ?)", + ) + .bind(&session_id) + .bind(&account_id) + .bind(&*creation_time) + .bind(&*last_use_time) + .bind(&*session.sender_key) + .bind(&pickle) + .execute(&mut *connection) + .await?; + } Ok(()) } @@ -806,14 +809,14 @@ mod test { let (mut store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; - assert!(store.save_session(session.clone()).await.is_err()); + assert!(store.save_sessions(&[session.clone()]).await.is_err()); store .save_account(account.clone()) .await .expect("Can't save account"); - store.save_session(session).await.unwrap(); + store.save_sessions(&[session]).await.unwrap(); } #[tokio::test] @@ -824,7 +827,7 @@ mod test { .save_account(account.clone()) .await .expect("Can't save account"); - store.save_session(session.clone()).await.unwrap(); + store.save_sessions(&[session.clone()]).await.unwrap(); let sessions = store .load_sessions_for(&session.sender_key) @@ -846,7 +849,7 @@ mod test { .save_account(account.clone()) .await .expect("Can't save account"); - store.save_session(session).await.unwrap(); + store.save_sessions(&[session]).await.unwrap(); let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); let sessions_lock = sessions.lock().await;