crypto: Move the sessions cache into the cryptostore.
parent
fca8062da0
commit
e4dcca550c
|
@ -404,7 +404,7 @@ impl OlmMachine {
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
message: &OlmMessage,
|
message: &OlmMessage,
|
||||||
) -> Result<Option<String>> {
|
) -> Result<Option<String>> {
|
||||||
let mut s = self.sessions.get_mut(sender_key).await;
|
let s = self.sessions.get(sender_key);
|
||||||
|
|
||||||
let sessions = if let Some(s) = s {
|
let sessions = if let Some(s) = s {
|
||||||
s
|
s
|
||||||
|
@ -412,20 +412,24 @@ impl OlmMachine {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
for session in sessions.lock().await.iter_mut() {
|
for session in sessions {
|
||||||
let mut matches = false;
|
let mut matches = false;
|
||||||
|
|
||||||
|
let mut session_lock = session.lock().await;
|
||||||
|
|
||||||
if let OlmMessage::PreKey(m) = &message {
|
if let OlmMessage::PreKey(m) = &message {
|
||||||
matches = session.matches(sender_key, m.clone())?;
|
matches = session_lock.matches(sender_key, m.clone())?;
|
||||||
if !matches {
|
if !matches {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let ret = session.decrypt(message.clone());
|
let ret = session_lock.decrypt(message.clone());
|
||||||
|
|
||||||
if let Ok(p) = ret {
|
if let Ok(p) = ret {
|
||||||
// TODO save the session.
|
if let Some(store) = self.store.as_mut() {
|
||||||
|
store.save_session(session.clone()).await?;
|
||||||
|
}
|
||||||
return Ok(Some(p));
|
return Ok(Some(p));
|
||||||
} else {
|
} else {
|
||||||
if matches {
|
if matches {
|
||||||
|
@ -456,6 +460,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
let plaintext = session.decrypt(message)?;
|
let plaintext = session.decrypt(message)?;
|
||||||
self.sessions.add(session).await;
|
self.sessions.add(session).await;
|
||||||
|
|
||||||
// TODO save the session
|
// TODO save the session
|
||||||
plaintext
|
plaintext
|
||||||
};
|
};
|
||||||
|
|
|
@ -21,31 +21,31 @@ use super::olm::{InboundGroupSession, Session};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct SessionStore {
|
pub struct SessionStore {
|
||||||
entries: Mutex<HashMap<String, Arc<Mutex<Vec<Session>>>>>,
|
entries: HashMap<String, Vec<Arc<Mutex<Session>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionStore {
|
impl SessionStore {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
SessionStore {
|
SessionStore {
|
||||||
entries: Mutex::new(HashMap::new()),
|
entries: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add(&mut self, session: Session) {
|
pub async fn add(&mut self, session: Session) {
|
||||||
let mut entries = self.entries.lock().await;
|
if !self.entries.contains_key(&session.sender_key) {
|
||||||
|
self.entries
|
||||||
if !entries.contains_key(&session.sender_key) {
|
.insert(session.sender_key.to_owned(), Vec::new());
|
||||||
entries.insert(
|
|
||||||
session.sender_key.to_owned(),
|
|
||||||
Arc::new(Mutex::new(Vec::new())),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
let mut sessions = entries.get_mut(&session.sender_key).unwrap();
|
let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
|
||||||
sessions.lock().await.push(session);
|
sessions.push(Arc::new(Mutex::new(session)));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_mut(&mut self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
|
pub fn get(&self, sender_key: &str) -> Option<&Vec<Arc<Mutex<Session>>>> {
|
||||||
self.entries.lock().await.get_mut(sender_key).cloned()
|
self.entries.get(sender_key)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Arc<Mutex<Session>>>) {
|
||||||
|
self.entries.insert(sender_key.to_owned(), sessions);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -64,5 +64,6 @@ pub trait CryptoStore: Debug + Send + Sync {
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
||||||
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
|
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
|
||||||
async fn load_sessions(&mut self) -> Result<Vec<Session>>;
|
async fn get_sessions(&mut self, sender_key: &str)
|
||||||
|
-> Result<Option<&Vec<Arc<Mutex<Session>>>>>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,12 +26,14 @@ use tokio::sync::Mutex;
|
||||||
use zeroize::Zeroizing;
|
use zeroize::Zeroizing;
|
||||||
|
|
||||||
use super::{Account, CryptoStore, CryptoStoreError, Result, Session};
|
use super::{Account, CryptoStore, CryptoStoreError, Result, Session};
|
||||||
|
use crate::crypto::memory_stores::SessionStore;
|
||||||
|
|
||||||
pub struct SqliteStore {
|
pub struct SqliteStore {
|
||||||
user_id: Arc<String>,
|
user_id: Arc<String>,
|
||||||
device_id: Arc<String>,
|
device_id: Arc<String>,
|
||||||
account_id: Option<i64>,
|
account_id: Option<i64>,
|
||||||
path: PathBuf,
|
path: PathBuf,
|
||||||
|
sessions: SessionStore,
|
||||||
connection: Arc<Mutex<SqliteConnection>>,
|
connection: Arc<Mutex<SqliteConnection>>,
|
||||||
pickle_passphrase: Option<Zeroizing<String>>,
|
pickle_passphrase: Option<Zeroizing<String>>,
|
||||||
}
|
}
|
||||||
|
@ -75,6 +77,7 @@ impl SqliteStore {
|
||||||
user_id: Arc::new(user_id.to_owned()),
|
user_id: Arc::new(user_id.to_owned()),
|
||||||
device_id: Arc::new(device_id.to_owned()),
|
device_id: Arc::new(device_id.to_owned()),
|
||||||
account_id: None,
|
account_id: None,
|
||||||
|
sessions: SessionStore::new(),
|
||||||
path: path.as_ref().to_owned(),
|
path: path.as_ref().to_owned(),
|
||||||
connection: Arc::new(Mutex::new(connection)),
|
connection: Arc::new(Mutex::new(connection)),
|
||||||
pickle_passphrase: passphrase,
|
pickle_passphrase: passphrase,
|
||||||
|
@ -122,6 +125,60 @@ impl SqliteStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_sessions_for(
|
||||||
|
&mut self,
|
||||||
|
sender_key: &str,
|
||||||
|
) -> Result<Option<&Vec<Arc<Mutex<Session>>>>> {
|
||||||
|
let loaded_sessions = self.sessions.get(sender_key).is_some();
|
||||||
|
|
||||||
|
if !loaded_sessions {
|
||||||
|
let sessions = self.load_session_for(sender_key).await?;
|
||||||
|
|
||||||
|
if !sessions.is_empty() {
|
||||||
|
self.sessions.set_for_sender(sender_key, sessions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.sessions.get(sender_key))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_session_for(&mut self, sender_key: &str) -> Result<Vec<Arc<Mutex<Session>>>> {
|
||||||
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
let rows: Vec<(String, String, String, String)> = query_as(
|
||||||
|
"SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ? and sender_key = ?"
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.bind(sender_key)
|
||||||
|
.fetch_all(&mut *connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
Ok(rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| {
|
||||||
|
let pickle = &row.0;
|
||||||
|
let sender_key = &row.1;
|
||||||
|
let creation_time = now
|
||||||
|
.checked_sub(serde_json::from_str::<Duration>(&row.2)?)
|
||||||
|
.ok_or(CryptoStoreError::SessionTimestampError)?;
|
||||||
|
let last_use_time = now
|
||||||
|
.checked_sub(serde_json::from_str::<Duration>(&row.3)?)
|
||||||
|
.ok_or(CryptoStoreError::SessionTimestampError)?;
|
||||||
|
|
||||||
|
Ok(Arc::new(Mutex::new(Session::from_pickle(
|
||||||
|
pickle.to_string(),
|
||||||
|
self.get_pickle_mode(),
|
||||||
|
sender_key.to_string(),
|
||||||
|
creation_time,
|
||||||
|
last_use_time,
|
||||||
|
)?)))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<Arc<Mutex<Session>>>>>()?)
|
||||||
|
}
|
||||||
|
|
||||||
fn get_pickle_mode(&self) -> PicklingMode {
|
fn get_pickle_mode(&self) -> PicklingMode {
|
||||||
match &self.pickle_passphrase {
|
match &self.pickle_passphrase {
|
||||||
Some(p) => PicklingMode::Encrypted {
|
Some(p) => PicklingMode::Encrypted {
|
||||||
|
@ -233,40 +290,11 @@ impl CryptoStore for SqliteStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load_sessions(&mut self) -> Result<Vec<Session>> {
|
async fn get_sessions<'a>(
|
||||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
&'a mut self,
|
||||||
let mut connection = self.connection.lock().await;
|
sender_key: &str,
|
||||||
|
) -> Result<Option<&'a Vec<Arc<Mutex<Session>>>>> {
|
||||||
let rows: Vec<(String, String, String, String)> = query_as(
|
Ok(self.get_sessions_for(sender_key).await?)
|
||||||
"SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ?"
|
|
||||||
)
|
|
||||||
.bind(account_id)
|
|
||||||
.fetch_all(&mut *connection)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let now = Instant::now();
|
|
||||||
|
|
||||||
Ok(rows
|
|
||||||
.iter()
|
|
||||||
.map(|row| {
|
|
||||||
let pickle = &row.0;
|
|
||||||
let sender_key = &row.1;
|
|
||||||
let creation_time = now
|
|
||||||
.checked_sub(serde_json::from_str::<Duration>(&row.2)?)
|
|
||||||
.ok_or(CryptoStoreError::SessionTimestampError)?;
|
|
||||||
let last_use_time = now
|
|
||||||
.checked_sub(serde_json::from_str::<Duration>(&row.3)?)
|
|
||||||
.ok_or(CryptoStoreError::SessionTimestampError)?;
|
|
||||||
|
|
||||||
Ok(Session::from_pickle(
|
|
||||||
pickle.to_string(),
|
|
||||||
self.get_pickle_mode(),
|
|
||||||
sender_key.to_string(),
|
|
||||||
creation_time,
|
|
||||||
last_use_time,
|
|
||||||
)?)
|
|
||||||
})
|
|
||||||
.collect::<Result<Vec<Session>>>()?)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue