From e4dcca550ca2b3412c252a7bf667ecc635c9c407 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 27 Mar 2020 12:09:54 +0100 Subject: [PATCH] crypto: Move the sessions cache into the cryptostore. --- src/crypto/machine.rs | 15 ++++-- src/crypto/memory_stores.rs | 26 +++++----- src/crypto/store/mod.rs | 3 +- src/crypto/store/sqlite.rs | 96 ++++++++++++++++++++++++------------- 4 files changed, 87 insertions(+), 53 deletions(-) diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 02edd5d0..95a2095d 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -404,7 +404,7 @@ impl OlmMachine { sender_key: &str, message: &OlmMessage, ) -> Result> { - let mut s = self.sessions.get_mut(sender_key).await; + let s = self.sessions.get(sender_key); let sessions = if let Some(s) = s { s @@ -412,20 +412,24 @@ impl OlmMachine { return Ok(None); }; - for session in sessions.lock().await.iter_mut() { + for session in sessions { let mut matches = false; + let mut session_lock = session.lock().await; + if let OlmMessage::PreKey(m) = &message { - matches = session.matches(sender_key, m.clone())?; + matches = session_lock.matches(sender_key, m.clone())?; if !matches { continue; } } - let ret = session.decrypt(message.clone()); + let ret = session_lock.decrypt(message.clone()); if let Ok(p) = ret { - // TODO save the session. + if let Some(store) = self.store.as_mut() { + store.save_session(session.clone()).await?; + } return Ok(Some(p)); } else { if matches { @@ -456,6 +460,7 @@ impl OlmMachine { let plaintext = session.decrypt(message)?; self.sessions.add(session).await; + // TODO save the session plaintext }; diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index e12004b9..56b68177 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -21,31 +21,31 @@ use super::olm::{InboundGroupSession, Session}; #[derive(Debug)] pub struct SessionStore { - entries: Mutex>>>>, + entries: HashMap>>>, } impl SessionStore { pub fn new() -> Self { SessionStore { - entries: Mutex::new(HashMap::new()), + entries: HashMap::new(), } } pub async fn add(&mut self, session: Session) { - let mut entries = self.entries.lock().await; - - if !entries.contains_key(&session.sender_key) { - entries.insert( - session.sender_key.to_owned(), - Arc::new(Mutex::new(Vec::new())), - ); + if !self.entries.contains_key(&session.sender_key) { + self.entries + .insert(session.sender_key.to_owned(), Vec::new()); } - let mut sessions = entries.get_mut(&session.sender_key).unwrap(); - sessions.lock().await.push(session); + let mut sessions = self.entries.get_mut(&session.sender_key).unwrap(); + sessions.push(Arc::new(Mutex::new(session))); } - pub async fn get_mut(&mut self, sender_key: &str) -> Option>>> { - self.entries.lock().await.get_mut(sender_key).cloned() + pub fn get(&self, sender_key: &str) -> Option<&Vec>>> { + self.entries.get(sender_key) + } + + pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec>>) { + self.entries.insert(sender_key.to_owned(), sessions); } } diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index 405b5b7a..bb10378e 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -64,5 +64,6 @@ pub trait CryptoStore: Debug + Send + Sync { async fn load_account(&mut self) -> Result>; async fn save_account(&mut self, account: Arc>) -> Result<()>; async fn save_session(&mut self, session: Arc>) -> Result<()>; - async fn load_sessions(&mut self) -> Result>; + async fn get_sessions(&mut self, sender_key: &str) + -> Result>>>>; } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index c01c6d85..2e2af1a7 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -26,12 +26,14 @@ use tokio::sync::Mutex; use zeroize::Zeroizing; use super::{Account, CryptoStore, CryptoStoreError, Result, Session}; +use crate::crypto::memory_stores::SessionStore; pub struct SqliteStore { user_id: Arc, device_id: Arc, account_id: Option, path: PathBuf, + sessions: SessionStore, connection: Arc>, pickle_passphrase: Option>, } @@ -75,6 +77,7 @@ impl SqliteStore { user_id: Arc::new(user_id.to_owned()), device_id: Arc::new(device_id.to_owned()), account_id: None, + sessions: SessionStore::new(), path: path.as_ref().to_owned(), connection: Arc::new(Mutex::new(connection)), pickle_passphrase: passphrase, @@ -122,6 +125,60 @@ impl SqliteStore { Ok(()) } + async fn get_sessions_for( + &mut self, + sender_key: &str, + ) -> Result>>>> { + let loaded_sessions = self.sessions.get(sender_key).is_some(); + + if !loaded_sessions { + let sessions = self.load_session_for(sender_key).await?; + + if !sessions.is_empty() { + self.sessions.set_for_sender(sender_key, sessions); + } + } + + Ok(self.sessions.get(sender_key)) + } + + async fn load_session_for(&mut self, sender_key: &str) -> Result>>> { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let rows: Vec<(String, String, String, String)> = query_as( + "SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ? and sender_key = ?" + ) + .bind(account_id) + .bind(sender_key) + .fetch_all(&mut *connection) + .await?; + + let now = Instant::now(); + + Ok(rows + .iter() + .map(|row| { + let pickle = &row.0; + let sender_key = &row.1; + let creation_time = now + .checked_sub(serde_json::from_str::(&row.2)?) + .ok_or(CryptoStoreError::SessionTimestampError)?; + let last_use_time = now + .checked_sub(serde_json::from_str::(&row.3)?) + .ok_or(CryptoStoreError::SessionTimestampError)?; + + Ok(Arc::new(Mutex::new(Session::from_pickle( + pickle.to_string(), + self.get_pickle_mode(), + sender_key.to_string(), + creation_time, + last_use_time, + )?))) + }) + .collect::>>>>()?) + } + fn get_pickle_mode(&self) -> PicklingMode { match &self.pickle_passphrase { Some(p) => PicklingMode::Encrypted { @@ -233,40 +290,11 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn load_sessions(&mut self) -> Result> { - let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let rows: Vec<(String, String, String, String)> = query_as( - "SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ?" - ) - .bind(account_id) - .fetch_all(&mut *connection) - .await?; - - let now = Instant::now(); - - Ok(rows - .iter() - .map(|row| { - let pickle = &row.0; - let sender_key = &row.1; - let creation_time = now - .checked_sub(serde_json::from_str::(&row.2)?) - .ok_or(CryptoStoreError::SessionTimestampError)?; - let last_use_time = now - .checked_sub(serde_json::from_str::(&row.3)?) - .ok_or(CryptoStoreError::SessionTimestampError)?; - - Ok(Session::from_pickle( - pickle.to_string(), - self.get_pickle_mode(), - sender_key.to_string(), - creation_time, - last_use_time, - )?) - }) - .collect::>>()?) + async fn get_sessions<'a>( + &'a mut self, + sender_key: &str, + ) -> Result>>>> { + Ok(self.get_sessions_for(sender_key).await?) } }