crypto: Allow session to be saved in a batched way.

master
Damir Jelić 2020-04-30 12:08:38 +02:00
parent e33fd098bc
commit 5de32c025f
4 changed files with 42 additions and 36 deletions

View File

@ -358,7 +358,7 @@ impl OlmMachine {
} }
}; };
if let Err(e) = self.store.save_session(session).await { if let Err(e) = self.store.save_sessions(&[session]).await {
error!("Failed to store newly created Olm session {}", e); error!("Failed to store newly created Olm session {}", e);
continue; continue;
} }
@ -739,7 +739,7 @@ impl OlmMachine {
// Decryption was successful, save the new ratchet state of the // Decryption was successful, save the new ratchet state of the
// session that was used to decrypt the message. // session that was used to decrypt the message.
trace!("Saved the new session state for {}", sender); trace!("Saved the new session state for {}", sender);
self.store.save_session(session).await?; self.store.save_sessions(&[session]).await?;
} }
Ok(plaintext) Ok(plaintext)
@ -804,7 +804,7 @@ impl OlmMachine {
let plaintext = session.decrypt(message).await?; let plaintext = session.decrypt(message).await?;
// Save the new ratcheted state of the session. // Save the new ratcheted state of the session.
self.store.save_session(session).await?; self.store.save_sessions(&[session]).await?;
plaintext plaintext
}; };
@ -1055,7 +1055,7 @@ impl OlmMachine {
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload))); .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload)));
let ciphertext = session.encrypt(&plaintext).await.to_tuple(); let ciphertext = session.encrypt(&plaintext).await.to_tuple();
self.store.save_session(session).await?; self.store.save_sessions(&[session]).await?;
let message_type: usize = ciphertext.0.into(); let message_type: usize = ciphertext.0.into();

View File

@ -52,8 +52,11 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn save_session(&mut self, session: Session) -> Result<()> { async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
self.sessions.add(session).await; for session in sessions {
self.sessions.add(session.clone()).await;
}
Ok(()) Ok(())
} }
@ -125,7 +128,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none()); assert!(store.load_account().await.unwrap().is_none());
store.save_account(account).await.unwrap(); store.save_account(account).await.unwrap();
store.save_session(session.clone()).await.unwrap(); store.save_sessions(&[session.clone()]).await.unwrap();
let sessions = store let sessions = store
.get_sessions(&session.sender_key) .get_sessions(&session.sender_key)

View File

@ -75,12 +75,12 @@ pub trait CryptoStore: Debug + Send + Sync {
/// * `account` - The account that should be stored. /// * `account` - The account that should be stored.
async fn save_account(&mut self, account: Account) -> Result<()>; async fn save_account(&mut self, account: Account) -> Result<()>;
/// Save the given session in the store. /// Save the given sessions in the store.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `session` - The session that should be stored. /// * `session` - The sessions that should be stored.
async fn save_session(&mut self, session: Session) -> Result<()>; async fn save_sessions(&mut self, session: &[Session]) -> Result<()>;
/// Get all the sessions that belong to the given sender key. /// Get all the sessions that belong to the given sender key.
/// ///

View File

@ -527,7 +527,9 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn save_session(&mut self, session: Session) -> Result<()> { async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
// TODO turn this into a transaction
for session in sessions {
self.lazy_load_sessions(&session.sender_key).await?; self.lazy_load_sessions(&session.sender_key).await?;
self.sessions.add(session.clone()).await; self.sessions.add(session.clone()).await;
@ -553,6 +555,7 @@ impl CryptoStore for SqliteStore {
.bind(&pickle) .bind(&pickle)
.execute(&mut *connection) .execute(&mut *connection)
.await?; .await?;
}
Ok(()) Ok(())
} }
@ -806,14 +809,14 @@ mod test {
let (mut store, _dir) = get_store(None).await; let (mut store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
assert!(store.save_session(session.clone()).await.is_err()); assert!(store.save_sessions(&[session.clone()]).await.is_err());
store store
.save_account(account.clone()) .save_account(account.clone())
.await .await
.expect("Can't save account"); .expect("Can't save account");
store.save_session(session).await.unwrap(); store.save_sessions(&[session]).await.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -824,7 +827,7 @@ mod test {
.save_account(account.clone()) .save_account(account.clone())
.await .await
.expect("Can't save account"); .expect("Can't save account");
store.save_session(session.clone()).await.unwrap(); store.save_sessions(&[session.clone()]).await.unwrap();
let sessions = store let sessions = store
.load_sessions_for(&session.sender_key) .load_sessions_for(&session.sender_key)
@ -846,7 +849,7 @@ mod test {
.save_account(account.clone()) .save_account(account.clone())
.await .await
.expect("Can't save account"); .expect("Can't save account");
store.save_session(session).await.unwrap(); store.save_sessions(&[session]).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
let sessions_lock = sessions.lock().await; let sessions_lock = sessions.lock().await;