crypto: Move the session mutex into the Session struct.

master
Damir Jelić 2020-04-14 14:05:18 +02:00
parent 5f6cbbb193
commit 25e60d398b
6 changed files with 73 additions and 82 deletions

View File

@ -18,7 +18,6 @@ use std::mem;
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
use std::path::Path; use std::path::Path;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
use super::error::{OlmError, Result, SignatureError, VerificationResult}; use super::error::{OlmError, Result, SignatureError, VerificationResult};
@ -34,7 +33,6 @@ use api::r0::keys;
use cjson; use cjson;
use olm_rs::{session::OlmMessage, utility::OlmUtility}; use olm_rs::{session::OlmMessage, utility::OlmUtility};
use serde_json::{json, Value}; use serde_json::{json, Value};
use tokio::sync::Mutex;
use tracing::{debug, error, info, instrument, trace, warn}; use tracing::{debug, error, info, instrument, trace, warn};
use ruma_client_api::r0::client_exchange::{ use ruma_client_api::r0::client_exchange::{
@ -658,19 +656,17 @@ impl OlmMachine {
return Ok(None); return Ok(None);
}; };
for session in &*sessions.lock().await { for session in &mut *sessions.lock().await {
let mut matches = false; let mut matches = false;
let mut session_lock = session.lock().await;
if let OlmMessage::PreKey(m) = &message { if let OlmMessage::PreKey(m) = &message {
matches = session_lock.matches(sender_key, m.clone())?; matches = session.matches(sender_key, m.clone()).await?;
if !matches { if !matches {
continue; continue;
} }
} }
let ret = session_lock.decrypt(message.clone()); let ret = session.decrypt(message.clone()).await;
if let Ok(p) = ret { if let Ok(p) = ret {
self.store.save_session(session.clone()).await?; self.store.save_session(session.clone()).await?;
@ -706,7 +702,7 @@ impl OlmMachine {
} }
}; };
let plaintext = session.decrypt(message)?; let plaintext = session.decrypt(message).await?;
self.store.add_and_save_session(session).await?; self.store.add_and_save_session(session).await?;
plaintext plaintext
}; };
@ -861,7 +857,7 @@ impl OlmMachine {
async fn olm_encrypt( async fn olm_encrypt(
&mut self, &mut self,
session: Arc<Mutex<Session>>, mut session: Session,
recipient_device: &Device, recipient_device: &Device,
event_type: EventType, event_type: EventType,
content: Value, content: Value,
@ -892,7 +888,7 @@ impl OlmMachine {
let plaintext = cjson::to_string(&payload) let plaintext = cjson::to_string(&payload)
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload))); .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload)));
let ciphertext = session.lock().await.encrypt(&plaintext).to_tuple(); let ciphertext = session.encrypt(&plaintext).await.to_tuple();
self.store.save_session(session).await?; self.store.save_session(session).await?;
let message_type: usize = ciphertext.0.into(); let message_type: usize = ciphertext.0.into();

View File

@ -24,7 +24,7 @@ use crate::identifiers::{DeviceId, RoomId, UserId};
#[derive(Debug)] #[derive(Debug)]
pub struct SessionStore { pub struct SessionStore {
entries: HashMap<String, Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>, entries: HashMap<String, Arc<Mutex<Vec<Session>>>>,
} }
impl SessionStore { impl SessionStore {
@ -34,25 +34,24 @@ impl SessionStore {
} }
} }
pub async fn add(&mut self, session: Session) -> Arc<Mutex<Session>> { pub async fn add(&mut self, session: Session) -> Session {
if !self.entries.contains_key(&session.sender_key) { if !self.entries.contains_key(&*session.sender_key) {
self.entries.insert( self.entries.insert(
session.sender_key.to_owned(), session.sender_key.to_string(),
Arc::new(Mutex::new(Vec::new())), Arc::new(Mutex::new(Vec::new())),
); );
} }
let sessions = self.entries.get_mut(&session.sender_key).unwrap(); let sessions = self.entries.get_mut(&*session.sender_key).unwrap();
let session = Arc::new(Mutex::new(session));
sessions.lock().await.push(session.clone()); sessions.lock().await.push(session.clone());
session session
} }
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> { pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
self.entries.get(sender_key).cloned() self.entries.get(sender_key).cloned()
} }
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Arc<Mutex<Session>>>) { pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Session>) {
self.entries self.entries
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions))); .insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
} }

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use std::fmt; use std::fmt;
use std::mem;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -170,12 +171,14 @@ impl Account {
.create_outbound_session(their_identity_key, &their_one_time_key.key)?; .create_outbound_session(their_identity_key, &their_one_time_key.key)?;
let now = Instant::now(); let now = Instant::now();
let session_id = session.session_id();
Ok(Session { Ok(Session {
inner: session, inner: Arc::new(Mutex::new(session)),
sender_key: their_identity_key.to_owned(), session_id: Arc::new(session_id),
creation_time: now.clone(), sender_key: Arc::new(their_identity_key.to_owned()),
last_use_time: now, creation_time: Arc::new(now.clone()),
last_use_time: Arc::new(now),
}) })
} }
@ -209,12 +212,14 @@ impl Account {
); );
let now = Instant::now(); let now = Instant::now();
let session_id = session.session_id();
Ok(Session { Ok(Session {
inner: session, inner: Arc::new(Mutex::new(session)),
sender_key: their_identity_key.to_owned(), session_id: Arc::new(session_id),
creation_time: now.clone(), sender_key: Arc::new(their_identity_key.to_owned()),
last_use_time: now, creation_time: Arc::new(now.clone()),
last_use_time: Arc::new(now),
}) })
} }
} }
@ -225,16 +230,17 @@ impl PartialEq for Account {
} }
} }
#[derive(Debug)]
/// The Olm Session. /// The Olm Session.
/// ///
/// Sessions are used to exchange encrypted messages between two /// Sessions are used to exchange encrypted messages between two
/// accounts/devices. /// accounts/devices.
#[derive(Debug, Clone)]
pub struct Session { pub struct Session {
inner: OlmSession, inner: Arc<Mutex<OlmSession>>,
pub(crate) sender_key: String, session_id: Arc<String>,
pub(crate) creation_time: Instant, pub(crate) sender_key: Arc<String>,
pub(crate) last_use_time: Instant, pub(crate) creation_time: Arc<Instant>,
pub(crate) last_use_time: Arc<Instant>,
} }
impl Session { impl Session {
@ -246,9 +252,9 @@ impl Session {
/// # Arguments /// # Arguments
/// ///
/// * `message` - The Olm message that should be decrypted. /// * `message` - The Olm message that should be decrypted.
pub fn decrypt(&mut self, message: OlmMessage) -> Result<String, OlmSessionError> { pub async fn decrypt(&mut self, message: OlmMessage) -> Result<String, OlmSessionError> {
let plaintext = self.inner.decrypt(message)?; let plaintext = self.inner.lock().await.decrypt(message)?;
self.last_use_time = Instant::now(); mem::replace(&mut self.last_use_time, Arc::new(Instant::now()));
Ok(plaintext) Ok(plaintext)
} }
@ -259,9 +265,9 @@ impl Session {
/// # Arguments /// # Arguments
/// ///
/// * `plaintext` - The plaintext that should be encrypted. /// * `plaintext` - The plaintext that should be encrypted.
pub fn encrypt(&mut self, plaintext: &str) -> OlmMessage { pub async fn encrypt(&mut self, plaintext: &str) -> OlmMessage {
let message = self.inner.encrypt(plaintext); let message = self.inner.lock().await.encrypt(plaintext);
self.last_use_time = Instant::now(); mem::replace(&mut self.last_use_time, Arc::new(Instant::now()));
message message
} }
@ -276,18 +282,20 @@ impl Session {
/// that encrypted this Olm message. /// that encrypted this Olm message.
/// ///
/// * `message` - The pre-key Olm message that should be checked. /// * `message` - The pre-key Olm message that should be checked.
pub fn matches( pub async fn matches(
&self, &self,
their_identity_key: &str, their_identity_key: &str,
message: PreKeyMessage, message: PreKeyMessage,
) -> Result<bool, OlmSessionError> { ) -> Result<bool, OlmSessionError> {
self.inner self.inner
.lock()
.await
.matches_inbound_session_from(their_identity_key, message) .matches_inbound_session_from(their_identity_key, message)
} }
/// Returns the unique identifier for this session. /// Returns the unique identifier for this session.
pub fn session_id(&self) -> String { pub fn session_id(&self) -> &str {
self.inner.session_id() &self.session_id
} }
/// Store the session as a base64 encoded string. /// Store the session as a base64 encoded string.
@ -296,8 +304,8 @@ impl Session {
/// ///
/// * `pickle_mode` - The mode that was used to pickle the session, either /// * `pickle_mode` - The mode that was used to pickle the session, either
/// an unencrypted mode or an encrypted using passphrase. /// an unencrypted mode or an encrypted using passphrase.
pub fn pickle(&self, pickle_mode: PicklingMode) -> String { pub async fn pickle(&self, pickle_mode: PicklingMode) -> String {
self.inner.pickle(pickle_mode) self.inner.lock().await.pickle(pickle_mode)
} }
/// Restore a Session from a previously pickled string. /// Restore a Session from a previously pickled string.
@ -328,11 +336,14 @@ impl Session {
last_use_time: Instant, last_use_time: Instant,
) -> Result<Self, OlmSessionError> { ) -> Result<Self, OlmSessionError> {
let session = OlmSession::unpickle(pickle, pickle_mode)?; let session = OlmSession::unpickle(pickle, pickle_mode)?;
let session_id = session.session_id();
Ok(Session { Ok(Session {
inner: session, inner: Arc::new(Mutex::new(session)),
sender_key, session_id: Arc::new(session_id),
creation_time, sender_key: Arc::new(sender_key),
last_use_time, creation_time: Arc::new(creation_time),
last_use_time: Arc::new(last_use_time),
}) })
} }
} }
@ -665,7 +676,7 @@ mod test {
let plaintext = "Hello world"; let plaintext = "Hello world";
let message = bob_session.encrypt(plaintext); let message = bob_session.encrypt(plaintext).await;
let prekey_message = match message.clone() { let prekey_message = match message.clone() {
OlmMessage::PreKey(m) => m, OlmMessage::PreKey(m) => m,
@ -680,7 +691,7 @@ mod test {
assert_eq!(bob_session.session_id(), alice_session.session_id()); assert_eq!(bob_session.session_id(), alice_session.session_id());
let decyrpted = alice_session.decrypt(message).unwrap(); let decyrpted = alice_session.decrypt(message).await.unwrap();
assert_eq!(plaintext, decyrpted); assert_eq!(plaintext, decyrpted);
} }
} }

View File

@ -52,7 +52,7 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn save_session(&mut self, _: Arc<Mutex<Session>>) -> Result<()> { async fn save_session(&mut self, _: Session) -> Result<()> {
Ok(()) Ok(())
} }
@ -61,10 +61,7 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn get_sessions( async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }

View File

@ -68,12 +68,9 @@ pub trait CryptoStore: Debug + Send + Sync {
async fn load_account(&mut self) -> Result<Option<Account>>; async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&mut self, account: Account) -> Result<()>; async fn save_account(&mut self, account: Account) -> Result<()>;
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>; async fn save_session(&mut self, session: Session) -> Result<()>;
async fn add_and_save_session(&mut self, session: Session) -> Result<()>; async fn add_and_save_session(&mut self, session: Session) -> Result<()>;
async fn get_sessions( async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>; async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
async fn get_inbound_group_session( async fn get_inbound_group_session(

View File

@ -155,7 +155,7 @@ impl SqliteStore {
async fn get_sessions_for( async fn get_sessions_for(
&mut self, &mut self,
sender_key: &str, sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> { ) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
let loaded_sessions = self.sessions.get(sender_key).is_some(); let loaded_sessions = self.sessions.get(sender_key).is_some();
if !loaded_sessions { if !loaded_sessions {
@ -169,7 +169,7 @@ impl SqliteStore {
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
async fn load_sessions_for(&mut self, sender_key: &str) -> Result<Vec<Arc<Mutex<Session>>>> { async fn load_sessions_for(&mut self, sender_key: &str) -> Result<Vec<Session>> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
@ -196,15 +196,15 @@ impl SqliteStore {
.checked_sub(serde_json::from_str::<Duration>(&row.3)?) .checked_sub(serde_json::from_str::<Duration>(&row.3)?)
.ok_or(CryptoStoreError::SessionTimestampError)?; .ok_or(CryptoStoreError::SessionTimestampError)?;
Ok(Arc::new(Mutex::new(Session::from_pickle( Ok(Session::from_pickle(
pickle.to_string(), pickle.to_string(),
self.get_pickle_mode(), self.get_pickle_mode(),
sender_key.to_string(), sender_key.to_string(),
creation_time, creation_time,
last_use_time, last_use_time,
)?))) )?)
}) })
.collect::<Result<Vec<Arc<Mutex<Session>>>>>()?) .collect::<Result<Vec<Session>>>()?)
} }
async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> { async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
@ -322,15 +322,13 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()> { async fn save_session(&mut self, session: Session) -> Result<()> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let session = session.lock().await;
let session_id = session.session_id(); let session_id = session.session_id();
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?; let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?; let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
let pickle = session.pickle(self.get_pickle_mode()); let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
@ -341,9 +339,9 @@ impl CryptoStore for SqliteStore {
) )
.bind(&session_id) .bind(&session_id)
.bind(&account_id) .bind(&account_id)
.bind(&creation_time) .bind(&*creation_time)
.bind(&last_use_time) .bind(&*last_use_time)
.bind(&session.sender_key) .bind(&*session.sender_key)
.bind(&pickle) .bind(&pickle)
.execute(&mut *connection) .execute(&mut *connection)
.await?; .await?;
@ -357,10 +355,7 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn get_sessions( async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
Ok(self.get_sessions_for(sender_key).await?) Ok(self.get_sessions_for(sender_key).await?)
} }
@ -565,7 +560,6 @@ mod test {
async fn save_session() { async fn save_session() {
let mut store = get_store().await; let mut store = get_store().await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
let session = Arc::new(Mutex::new(session));
assert!(store.save_session(session.clone()).await.is_err()); assert!(store.save_session(session.clone()).await.is_err());
@ -581,22 +575,19 @@ mod test {
async fn load_sessions() { async fn load_sessions() {
let mut store = get_store().await; let mut store = get_store().await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
let session = Arc::new(Mutex::new(session));
store store
.save_account(account.clone()) .save_account(account.clone())
.await .await
.expect("Can't save account"); .expect("Can't save account");
store.save_session(session.clone()).await.unwrap(); store.save_session(session.clone()).await.unwrap();
let sess = session.lock().await;
let sessions = store let sessions = store
.load_sessions_for(&sess.sender_key) .load_sessions_for(&session.sender_key)
.await .await
.expect("Can't load sessions"); .expect("Can't load sessions");
let loaded_session = &sessions[0]; let loaded_session = &sessions[0];
assert_eq!(*sess, *loaded_session.lock().await); assert_eq!(&session, loaded_session);
} }
#[tokio::test] #[tokio::test]
@ -604,7 +595,7 @@ mod test {
let mut store = get_store().await; let mut store = get_store().await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
let sender_key = session.sender_key.to_owned(); let sender_key = session.sender_key.to_owned();
let session_id = session.session_id(); let session_id = session.session_id().to_owned();
store store
.save_account(account.clone()) .save_account(account.clone())
@ -616,7 +607,7 @@ mod test {
let sessions_lock = sessions.lock().await; let sessions_lock = sessions.lock().await;
let session = &sessions_lock[0]; let session = &sessions_lock[0];
assert_eq!(session_id, *session.lock().await.session_id()); assert_eq!(session_id, session.session_id());
} }
#[tokio::test] #[tokio::test]