From abe13d7a2d3b55475d41d0a4554a98d378c5c2ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 26 Mar 2020 11:22:40 +0100 Subject: [PATCH] crypto: Make the session stores thread safe. --- src/crypto/machine.rs | 10 +++++---- src/crypto/memory_stores.rs | 44 ++++++++++++++++++++++++++++++++----- src/crypto/olm.rs | 3 +++ src/crypto/store/mod.rs | 16 ++------------ src/crypto/store/sqlite.rs | 7 ------ 5 files changed, 50 insertions(+), 30 deletions(-) diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index f9e42693..917f8795 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -398,7 +398,7 @@ impl OlmMachine { sender_key: &str, message: &OlmMessage, ) -> Result> { - let mut s = self.store.sessions_mut(sender_key).await?; + let mut s = self.sessions.get_mut(sender_key).await; let sessions = if let Some(s) = s { s @@ -406,7 +406,7 @@ impl OlmMachine { return Ok(None); }; - for session in sessions.iter_mut() { + for session in sessions.lock().await.iter_mut() { let mut matches = false; if let OlmMessage::PreKey(m) = &message { @@ -448,8 +448,10 @@ impl OlmMachine { } }; - session.decrypt(message)? + let plaintext = session.decrypt(message)?; + self.sessions.add(session).await; // TODO save the session + plaintext }; // TODO convert the plaintext to a ruma event. @@ -644,7 +646,7 @@ impl OlmMachine { // TODO check if the olm session is wedged and re-request the key. let session = session.ok_or(OlmError::MissingSession)?; - let (plaintext, _) = session.decrypt(content.ciphertext.clone())?; + let (plaintext, _) = session.lock().await.decrypt(content.ciphertext.clone())?; // TODO check the message index. // TODO check if this is from a verified device. diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index 6d0ac27c..e12004b9 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -12,12 +12,46 @@ // See the License for the specific language governing permissions and // limitations under the License. -use super::olm::InboundGroupSession; use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::Mutex; + +use super::olm::{InboundGroupSession, Session}; + +#[derive(Debug)] +pub struct SessionStore { + entries: Mutex>>>>, +} + +impl SessionStore { + pub fn new() -> Self { + SessionStore { + entries: Mutex::new(HashMap::new()), + } + } + + pub async fn add(&mut self, session: Session) { + let mut entries = self.entries.lock().await; + + if !entries.contains_key(&session.sender_key) { + entries.insert( + session.sender_key.to_owned(), + Arc::new(Mutex::new(Vec::new())), + ); + } + let mut sessions = entries.get_mut(&session.sender_key).unwrap(); + sessions.lock().await.push(session); + } + + pub async fn get_mut(&mut self, sender_key: &str) -> Option>>> { + self.entries.lock().await.get_mut(sender_key).cloned() + } +} #[derive(Debug)] pub struct GroupSessionStore { - entries: HashMap>>, + entries: HashMap>>>>, } impl GroupSessionStore { @@ -40,7 +74,7 @@ impl GroupSessionStore { } let mut sender_map = room_map.get_mut(&session.sender_key).unwrap(); - let ret = sender_map.insert(session.session_id(), session); + let ret = sender_map.insert(session.session_id(), Arc::new(Mutex::new(session))); ret.is_some() } @@ -50,9 +84,9 @@ impl GroupSessionStore { room_id: &str, sender_key: &str, session_id: &str, - ) -> Option<&InboundGroupSession> { + ) -> Option>> { self.entries .get(room_id) - .and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id))) + .and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id).cloned())) } } diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index a70fbe21..1fa9da06 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -116,6 +116,7 @@ impl Account { Ok(Session { inner: session, + sender_key: their_identity_key.to_owned(), creation_time: now.clone(), last_use_time: now, }) @@ -128,8 +129,10 @@ impl PartialEq for Account { } } +#[derive(Debug)] pub struct Session { inner: OlmSession, + pub(crate) sender_key: String, creation_time: Instant, last_use_time: Instant, } diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index 7ce34854..32e754f3 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -53,24 +53,19 @@ pub enum CryptoStoreError { pub type Result = std::result::Result; #[async_trait] -pub trait CryptoStore: Debug { +pub trait CryptoStore: Debug + Sync + Sync { async fn load_account(&mut self) -> Result>; async fn save_account(&mut self, account: Arc>) -> Result<()>; - async fn sessions_mut(&mut self, sender_key: &str) -> Result>>; } pub struct MemoryStore { pub(crate) account_info: Option<(String, bool)>, - sessions: HashMap>, } impl MemoryStore { /// Create a new empty memory store. pub fn new() -> Self { - MemoryStore { - account_info: None, - sessions: HashMap::new(), - } + MemoryStore { account_info: None } } } @@ -94,13 +89,6 @@ impl CryptoStore for MemoryStore { self.account_info = Some((pickle, acc.shared)); Ok(()) } - - async fn sessions_mut<'a>( - &'a mut self, - sender_key: &str, - ) -> Result>> { - Ok(self.sessions.get_mut(sender_key)) - } } impl std::fmt::Debug for MemoryStore { diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 09aa87c3..34140689 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -168,13 +168,6 @@ impl CryptoStore for SqliteStore { Ok(()) } - - async fn sessions_mut<'a>( - &'a mut self, - sender_key: &str, - ) -> Result>> { - todo!() - } } impl std::fmt::Debug for SqliteStore {