From 25e60d398be4ffd9259a1bbd871a241e5e42f59a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 14 Apr 2020 14:05:18 +0200 Subject: [PATCH] crypto: Move the session mutex into the Session struct. --- src/crypto/machine.rs | 16 +++----- src/crypto/memory_stores.rs | 15 ++++--- src/crypto/olm.rs | 71 +++++++++++++++++++-------------- src/crypto/store/memorystore.rs | 7 +--- src/crypto/store/mod.rs | 7 +--- src/crypto/store/sqlite.rs | 39 +++++++----------- 6 files changed, 73 insertions(+), 82 deletions(-) diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index d3fe2972..703f6836 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -18,7 +18,6 @@ use std::mem; #[cfg(feature = "sqlite-cryptostore")] use std::path::Path; use std::result::Result as StdResult; -use std::sync::Arc; use uuid::Uuid; use super::error::{OlmError, Result, SignatureError, VerificationResult}; @@ -34,7 +33,6 @@ use api::r0::keys; use cjson; use olm_rs::{session::OlmMessage, utility::OlmUtility}; use serde_json::{json, Value}; -use tokio::sync::Mutex; use tracing::{debug, error, info, instrument, trace, warn}; use ruma_client_api::r0::client_exchange::{ @@ -658,19 +656,17 @@ impl OlmMachine { return Ok(None); }; - for session in &*sessions.lock().await { + for session in &mut *sessions.lock().await { let mut matches = false; - let mut session_lock = session.lock().await; - if let OlmMessage::PreKey(m) = &message { - matches = session_lock.matches(sender_key, m.clone())?; + matches = session.matches(sender_key, m.clone()).await?; if !matches { continue; } } - let ret = session_lock.decrypt(message.clone()); + let ret = session.decrypt(message.clone()).await; if let Ok(p) = ret { self.store.save_session(session.clone()).await?; @@ -706,7 +702,7 @@ impl OlmMachine { } }; - let plaintext = session.decrypt(message)?; + let plaintext = session.decrypt(message).await?; self.store.add_and_save_session(session).await?; plaintext }; @@ -861,7 +857,7 @@ impl OlmMachine { async fn olm_encrypt( &mut self, - session: Arc>, + mut session: Session, recipient_device: &Device, event_type: EventType, content: Value, @@ -892,7 +888,7 @@ impl OlmMachine { let plaintext = cjson::to_string(&payload) .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload))); - let ciphertext = session.lock().await.encrypt(&plaintext).to_tuple(); + let ciphertext = session.encrypt(&plaintext).await.to_tuple(); self.store.save_session(session).await?; let message_type: usize = ciphertext.0.into(); diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index 89704ad1..413e1e06 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -24,7 +24,7 @@ use crate::identifiers::{DeviceId, RoomId, UserId}; #[derive(Debug)] pub struct SessionStore { - entries: HashMap>>>>>, + entries: HashMap>>>, } impl SessionStore { @@ -34,25 +34,24 @@ impl SessionStore { } } - pub async fn add(&mut self, session: Session) -> Arc> { - if !self.entries.contains_key(&session.sender_key) { + pub async fn add(&mut self, session: Session) -> Session { + if !self.entries.contains_key(&*session.sender_key) { self.entries.insert( - session.sender_key.to_owned(), + session.sender_key.to_string(), Arc::new(Mutex::new(Vec::new())), ); } - let sessions = self.entries.get_mut(&session.sender_key).unwrap(); - let session = Arc::new(Mutex::new(session)); + let sessions = self.entries.get_mut(&*session.sender_key).unwrap(); sessions.lock().await.push(session.clone()); session } - pub fn get(&self, sender_key: &str) -> Option>>>>> { + pub fn get(&self, sender_key: &str) -> Option>>> { self.entries.get(sender_key).cloned() } - pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec>>) { + pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec) { self.entries .insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions))); } diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index bf5e69ea..1fed7697 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::fmt; +use std::mem; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Instant; @@ -170,12 +171,14 @@ impl Account { .create_outbound_session(their_identity_key, &their_one_time_key.key)?; let now = Instant::now(); + let session_id = session.session_id(); Ok(Session { - inner: session, - sender_key: their_identity_key.to_owned(), - creation_time: now.clone(), - last_use_time: now, + inner: Arc::new(Mutex::new(session)), + session_id: Arc::new(session_id), + sender_key: Arc::new(their_identity_key.to_owned()), + creation_time: Arc::new(now.clone()), + last_use_time: Arc::new(now), }) } @@ -209,12 +212,14 @@ impl Account { ); let now = Instant::now(); + let session_id = session.session_id(); Ok(Session { - inner: session, - sender_key: their_identity_key.to_owned(), - creation_time: now.clone(), - last_use_time: now, + inner: Arc::new(Mutex::new(session)), + session_id: Arc::new(session_id), + sender_key: Arc::new(their_identity_key.to_owned()), + creation_time: Arc::new(now.clone()), + last_use_time: Arc::new(now), }) } } @@ -225,16 +230,17 @@ impl PartialEq for Account { } } -#[derive(Debug)] /// The Olm Session. /// /// Sessions are used to exchange encrypted messages between two /// accounts/devices. +#[derive(Debug, Clone)] pub struct Session { - inner: OlmSession, - pub(crate) sender_key: String, - pub(crate) creation_time: Instant, - pub(crate) last_use_time: Instant, + inner: Arc>, + session_id: Arc, + pub(crate) sender_key: Arc, + pub(crate) creation_time: Arc, + pub(crate) last_use_time: Arc, } impl Session { @@ -246,9 +252,9 @@ impl Session { /// # Arguments /// /// * `message` - The Olm message that should be decrypted. - pub fn decrypt(&mut self, message: OlmMessage) -> Result { - let plaintext = self.inner.decrypt(message)?; - self.last_use_time = Instant::now(); + pub async fn decrypt(&mut self, message: OlmMessage) -> Result { + let plaintext = self.inner.lock().await.decrypt(message)?; + mem::replace(&mut self.last_use_time, Arc::new(Instant::now())); Ok(plaintext) } @@ -259,9 +265,9 @@ impl Session { /// # Arguments /// /// * `plaintext` - The plaintext that should be encrypted. - pub fn encrypt(&mut self, plaintext: &str) -> OlmMessage { - let message = self.inner.encrypt(plaintext); - self.last_use_time = Instant::now(); + pub async fn encrypt(&mut self, plaintext: &str) -> OlmMessage { + let message = self.inner.lock().await.encrypt(plaintext); + mem::replace(&mut self.last_use_time, Arc::new(Instant::now())); message } @@ -276,18 +282,20 @@ impl Session { /// that encrypted this Olm message. /// /// * `message` - The pre-key Olm message that should be checked. - pub fn matches( + pub async fn matches( &self, their_identity_key: &str, message: PreKeyMessage, ) -> Result { self.inner + .lock() + .await .matches_inbound_session_from(their_identity_key, message) } /// Returns the unique identifier for this session. - pub fn session_id(&self) -> String { - self.inner.session_id() + pub fn session_id(&self) -> &str { + &self.session_id } /// Store the session as a base64 encoded string. @@ -296,8 +304,8 @@ impl Session { /// /// * `pickle_mode` - The mode that was used to pickle the session, either /// an unencrypted mode or an encrypted using passphrase. - pub fn pickle(&self, pickle_mode: PicklingMode) -> String { - self.inner.pickle(pickle_mode) + pub async fn pickle(&self, pickle_mode: PicklingMode) -> String { + self.inner.lock().await.pickle(pickle_mode) } /// Restore a Session from a previously pickled string. @@ -328,11 +336,14 @@ impl Session { last_use_time: Instant, ) -> Result { let session = OlmSession::unpickle(pickle, pickle_mode)?; + let session_id = session.session_id(); + Ok(Session { - inner: session, - sender_key, - creation_time, - last_use_time, + inner: Arc::new(Mutex::new(session)), + session_id: Arc::new(session_id), + sender_key: Arc::new(sender_key), + creation_time: Arc::new(creation_time), + last_use_time: Arc::new(last_use_time), }) } } @@ -665,7 +676,7 @@ mod test { let plaintext = "Hello world"; - let message = bob_session.encrypt(plaintext); + let message = bob_session.encrypt(plaintext).await; let prekey_message = match message.clone() { OlmMessage::PreKey(m) => m, @@ -680,7 +691,7 @@ mod test { assert_eq!(bob_session.session_id(), alice_session.session_id()); - let decyrpted = alice_session.decrypt(message).unwrap(); + let decyrpted = alice_session.decrypt(message).await.unwrap(); assert_eq!(plaintext, decyrpted); } } diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index d14e29b7..6daa1dba 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -52,7 +52,7 @@ impl CryptoStore for MemoryStore { Ok(()) } - async fn save_session(&mut self, _: Arc>) -> Result<()> { + async fn save_session(&mut self, _: Session) -> Result<()> { Ok(()) } @@ -61,10 +61,7 @@ impl CryptoStore for MemoryStore { Ok(()) } - async fn get_sessions( - &mut self, - sender_key: &str, - ) -> Result>>>>>> { + async fn get_sessions(&mut self, sender_key: &str) -> Result>>>> { Ok(self.sessions.get(sender_key)) } diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index 016cb334..d059b3a7 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -68,12 +68,9 @@ pub trait CryptoStore: Debug + Send + Sync { async fn load_account(&mut self) -> Result>; async fn save_account(&mut self, account: Account) -> Result<()>; - async fn save_session(&mut self, session: Arc>) -> Result<()>; + async fn save_session(&mut self, session: Session) -> Result<()>; async fn add_and_save_session(&mut self, session: Session) -> Result<()>; - async fn get_sessions( - &mut self, - sender_key: &str, - ) -> Result>>>>>>; + async fn get_sessions(&mut self, sender_key: &str) -> Result>>>>; async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result; async fn get_inbound_group_session( diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 1b57a832..cc4aea7c 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -155,7 +155,7 @@ impl SqliteStore { async fn get_sessions_for( &mut self, sender_key: &str, - ) -> Result>>>>>> { + ) -> Result>>>> { let loaded_sessions = self.sessions.get(sender_key).is_some(); if !loaded_sessions { @@ -169,7 +169,7 @@ impl SqliteStore { Ok(self.sessions.get(sender_key)) } - async fn load_sessions_for(&mut self, sender_key: &str) -> Result>>> { + async fn load_sessions_for(&mut self, sender_key: &str) -> Result> { let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -196,15 +196,15 @@ impl SqliteStore { .checked_sub(serde_json::from_str::(&row.3)?) .ok_or(CryptoStoreError::SessionTimestampError)?; - Ok(Arc::new(Mutex::new(Session::from_pickle( + Ok(Session::from_pickle( pickle.to_string(), self.get_pickle_mode(), sender_key.to_string(), creation_time, last_use_time, - )?))) + )?) }) - .collect::>>>>()?) + .collect::>>()?) } async fn load_inbound_group_sessions(&self) -> Result> { @@ -322,15 +322,13 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn save_session(&mut self, session: Arc>) -> Result<()> { + async fn save_session(&mut self, session: Session) -> Result<()> { let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; - let session = session.lock().await; - let session_id = session.session_id(); let creation_time = serde_json::to_string(&session.creation_time.elapsed())?; let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?; - let pickle = session.pickle(self.get_pickle_mode()); + let pickle = session.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; @@ -341,9 +339,9 @@ impl CryptoStore for SqliteStore { ) .bind(&session_id) .bind(&account_id) - .bind(&creation_time) - .bind(&last_use_time) - .bind(&session.sender_key) + .bind(&*creation_time) + .bind(&*last_use_time) + .bind(&*session.sender_key) .bind(&pickle) .execute(&mut *connection) .await?; @@ -357,10 +355,7 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn get_sessions( - &mut self, - sender_key: &str, - ) -> Result>>>>>> { + async fn get_sessions(&mut self, sender_key: &str) -> Result>>>> { Ok(self.get_sessions_for(sender_key).await?) } @@ -565,7 +560,6 @@ mod test { async fn save_session() { let mut store = get_store().await; let (account, session) = get_account_and_session().await; - let session = Arc::new(Mutex::new(session)); assert!(store.save_session(session.clone()).await.is_err()); @@ -581,22 +575,19 @@ mod test { async fn load_sessions() { let mut store = get_store().await; let (account, session) = get_account_and_session().await; - let session = Arc::new(Mutex::new(session)); store .save_account(account.clone()) .await .expect("Can't save account"); store.save_session(session.clone()).await.unwrap(); - let sess = session.lock().await; - let sessions = store - .load_sessions_for(&sess.sender_key) + .load_sessions_for(&session.sender_key) .await .expect("Can't load sessions"); let loaded_session = &sessions[0]; - assert_eq!(*sess, *loaded_session.lock().await); + assert_eq!(&session, loaded_session); } #[tokio::test] @@ -604,7 +595,7 @@ mod test { let mut store = get_store().await; let (account, session) = get_account_and_session().await; let sender_key = session.sender_key.to_owned(); - let session_id = session.session_id(); + let session_id = session.session_id().to_owned(); store .save_account(account.clone()) @@ -616,7 +607,7 @@ mod test { let sessions_lock = sessions.lock().await; let session = &sessions_lock[0]; - assert_eq!(session_id, *session.lock().await.session_id()); + assert_eq!(session_id, session.session_id()); } #[tokio::test]