crypto: Move the sessions cache into the cryptostore.

master
Damir Jelić 2020-03-27 12:09:54 +01:00
parent fca8062da0
commit e4dcca550c
4 changed files with 87 additions and 53 deletions

View File

@ -404,7 +404,7 @@ impl OlmMachine {
sender_key: &str, sender_key: &str,
message: &OlmMessage, message: &OlmMessage,
) -> Result<Option<String>> { ) -> Result<Option<String>> {
let mut s = self.sessions.get_mut(sender_key).await; let s = self.sessions.get(sender_key);
let sessions = if let Some(s) = s { let sessions = if let Some(s) = s {
s s
@ -412,20 +412,24 @@ impl OlmMachine {
return Ok(None); return Ok(None);
}; };
for session in sessions.lock().await.iter_mut() { for session in sessions {
let mut matches = false; let mut matches = false;
let mut session_lock = session.lock().await;
if let OlmMessage::PreKey(m) = &message { if let OlmMessage::PreKey(m) = &message {
matches = session.matches(sender_key, m.clone())?; matches = session_lock.matches(sender_key, m.clone())?;
if !matches { if !matches {
continue; continue;
} }
} }
let ret = session.decrypt(message.clone()); let ret = session_lock.decrypt(message.clone());
if let Ok(p) = ret { if let Ok(p) = ret {
// TODO save the session. if let Some(store) = self.store.as_mut() {
store.save_session(session.clone()).await?;
}
return Ok(Some(p)); return Ok(Some(p));
} else { } else {
if matches { if matches {
@ -456,6 +460,7 @@ impl OlmMachine {
let plaintext = session.decrypt(message)?; let plaintext = session.decrypt(message)?;
self.sessions.add(session).await; self.sessions.add(session).await;
// TODO save the session // TODO save the session
plaintext plaintext
}; };

View File

@ -21,31 +21,31 @@ use super::olm::{InboundGroupSession, Session};
#[derive(Debug)] #[derive(Debug)]
pub struct SessionStore { pub struct SessionStore {
entries: Mutex<HashMap<String, Arc<Mutex<Vec<Session>>>>>, entries: HashMap<String, Vec<Arc<Mutex<Session>>>>,
} }
impl SessionStore { impl SessionStore {
pub fn new() -> Self { pub fn new() -> Self {
SessionStore { SessionStore {
entries: Mutex::new(HashMap::new()), entries: HashMap::new(),
} }
} }
pub async fn add(&mut self, session: Session) { pub async fn add(&mut self, session: Session) {
let mut entries = self.entries.lock().await; if !self.entries.contains_key(&session.sender_key) {
self.entries
if !entries.contains_key(&session.sender_key) { .insert(session.sender_key.to_owned(), Vec::new());
entries.insert(
session.sender_key.to_owned(),
Arc::new(Mutex::new(Vec::new())),
);
} }
let mut sessions = entries.get_mut(&session.sender_key).unwrap(); let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
sessions.lock().await.push(session); sessions.push(Arc::new(Mutex::new(session)));
} }
pub async fn get_mut(&mut self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> { pub fn get(&self, sender_key: &str) -> Option<&Vec<Arc<Mutex<Session>>>> {
self.entries.lock().await.get_mut(sender_key).cloned() self.entries.get(sender_key)
}
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Arc<Mutex<Session>>>) {
self.entries.insert(sender_key.to_owned(), sessions);
} }
} }

View File

@ -64,5 +64,6 @@ pub trait CryptoStore: Debug + Send + Sync {
async fn load_account(&mut self) -> Result<Option<Account>>; async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>; async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>; async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
async fn load_sessions(&mut self) -> Result<Vec<Session>>; async fn get_sessions(&mut self, sender_key: &str)
-> Result<Option<&Vec<Arc<Mutex<Session>>>>>;
} }

View File

@ -26,12 +26,14 @@ use tokio::sync::Mutex;
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::{Account, CryptoStore, CryptoStoreError, Result, Session}; use super::{Account, CryptoStore, CryptoStoreError, Result, Session};
use crate::crypto::memory_stores::SessionStore;
pub struct SqliteStore { pub struct SqliteStore {
user_id: Arc<String>, user_id: Arc<String>,
device_id: Arc<String>, device_id: Arc<String>,
account_id: Option<i64>, account_id: Option<i64>,
path: PathBuf, path: PathBuf,
sessions: SessionStore,
connection: Arc<Mutex<SqliteConnection>>, connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>, pickle_passphrase: Option<Zeroizing<String>>,
} }
@ -75,6 +77,7 @@ impl SqliteStore {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_owned()),
device_id: Arc::new(device_id.to_owned()), device_id: Arc::new(device_id.to_owned()),
account_id: None, account_id: None,
sessions: SessionStore::new(),
path: path.as_ref().to_owned(), path: path.as_ref().to_owned(),
connection: Arc::new(Mutex::new(connection)), connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: passphrase, pickle_passphrase: passphrase,
@ -122,6 +125,60 @@ impl SqliteStore {
Ok(()) Ok(())
} }
async fn get_sessions_for(
&mut self,
sender_key: &str,
) -> Result<Option<&Vec<Arc<Mutex<Session>>>>> {
let loaded_sessions = self.sessions.get(sender_key).is_some();
if !loaded_sessions {
let sessions = self.load_session_for(sender_key).await?;
if !sessions.is_empty() {
self.sessions.set_for_sender(sender_key, sessions);
}
}
Ok(self.sessions.get(sender_key))
}
async fn load_session_for(&mut self, sender_key: &str) -> Result<Vec<Arc<Mutex<Session>>>> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let rows: Vec<(String, String, String, String)> = query_as(
"SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ? and sender_key = ?"
)
.bind(account_id)
.bind(sender_key)
.fetch_all(&mut *connection)
.await?;
let now = Instant::now();
Ok(rows
.iter()
.map(|row| {
let pickle = &row.0;
let sender_key = &row.1;
let creation_time = now
.checked_sub(serde_json::from_str::<Duration>(&row.2)?)
.ok_or(CryptoStoreError::SessionTimestampError)?;
let last_use_time = now
.checked_sub(serde_json::from_str::<Duration>(&row.3)?)
.ok_or(CryptoStoreError::SessionTimestampError)?;
Ok(Arc::new(Mutex::new(Session::from_pickle(
pickle.to_string(),
self.get_pickle_mode(),
sender_key.to_string(),
creation_time,
last_use_time,
)?)))
})
.collect::<Result<Vec<Arc<Mutex<Session>>>>>()?)
}
fn get_pickle_mode(&self) -> PicklingMode { fn get_pickle_mode(&self) -> PicklingMode {
match &self.pickle_passphrase { match &self.pickle_passphrase {
Some(p) => PicklingMode::Encrypted { Some(p) => PicklingMode::Encrypted {
@ -233,40 +290,11 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn load_sessions(&mut self) -> Result<Vec<Session>> { async fn get_sessions<'a>(
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; &'a mut self,
let mut connection = self.connection.lock().await; sender_key: &str,
) -> Result<Option<&'a Vec<Arc<Mutex<Session>>>>> {
let rows: Vec<(String, String, String, String)> = query_as( Ok(self.get_sessions_for(sender_key).await?)
"SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ?"
)
.bind(account_id)
.fetch_all(&mut *connection)
.await?;
let now = Instant::now();
Ok(rows
.iter()
.map(|row| {
let pickle = &row.0;
let sender_key = &row.1;
let creation_time = now
.checked_sub(serde_json::from_str::<Duration>(&row.2)?)
.ok_or(CryptoStoreError::SessionTimestampError)?;
let last_use_time = now
.checked_sub(serde_json::from_str::<Duration>(&row.3)?)
.ok_or(CryptoStoreError::SessionTimestampError)?;
Ok(Session::from_pickle(
pickle.to_string(),
self.get_pickle_mode(),
sender_key.to_string(),
creation_time,
last_use_time,
)?)
})
.collect::<Result<Vec<Session>>>()?)
} }
} }