diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 1fa9da06..7fc4dcb9 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -90,19 +90,38 @@ impl Account { self.inner.sign(string) } - pub fn pickle(&self, pickling_mode: PicklingMode) -> String { - self.inner.pickle(pickling_mode) + pub fn pickle(&self, pickle_mode: PicklingMode) -> String { + self.inner.pickle(pickle_mode) } pub fn from_pickle( pickle: String, - pickling_mode: PicklingMode, + pickle_mode: PicklingMode, shared: bool, ) -> Result { - let acc = OlmAccount::unpickle(pickle, pickling_mode)?; + let acc = OlmAccount::unpickle(pickle, pickle_mode)?; Ok(Account { inner: acc, shared }) } + pub fn create_outbound_session( + &self, + their_identity_key: &str, + their_one_time_key: &str, + ) -> Result { + let session = self + .inner + .create_outbound_session(their_identity_key, their_one_time_key)?; + + let now = Instant::now(); + + Ok(Session { + inner: session, + sender_key: their_identity_key.to_owned(), + creation_time: now.clone(), + last_use_time: now, + }) + } + pub fn create_inbound_session_from( &self, their_identity_key: &str, @@ -133,8 +152,8 @@ impl PartialEq for Account { pub struct Session { inner: OlmSession, pub(crate) sender_key: String, - creation_time: Instant, - last_use_time: Instant, + pub(crate) creation_time: Instant, + pub(crate) last_use_time: Instant, } impl Session { @@ -152,6 +171,36 @@ impl Session { self.inner .matches_inbound_session_from(their_identity_key, message) } + + pub fn session_id(&self) -> String { + self.inner.session_id() + } + + pub fn pickle(&self, pickle_mode: PicklingMode) -> String { + self.inner.pickle(pickle_mode) + } + + pub fn from_pickle( + pickle: String, + pickle_mode: PicklingMode, + sender_key: String, + creation_time: Instant, + last_use_time: Instant, + ) -> Result { + let session = OlmSession::unpickle(pickle, pickle_mode)?; + Ok(Session { + inner: session, + sender_key, + creation_time, + last_use_time, + }) + } +} + +impl PartialEq for Session { + fn eq(&self, other: &Self) -> bool { + self.session_id() == other.session_id() + } } #[derive(Debug)] diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index e7d031fc..405b5b7a 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use url::ParseError; use async_trait::async_trait; +use serde_json::Error as SerdeError; use thiserror::Error; use tokio::sync::Mutex; @@ -31,7 +32,7 @@ use olm_rs::PicklingMode; pub mod sqlite; #[cfg(feature = "sqlite-cryptostore")] -use sqlx::{sqlite::Sqlite, Error as SqlxError}; +use sqlx::Error as SqlxError; #[derive(Error, Debug)] pub enum CryptoStoreError { @@ -43,11 +44,17 @@ pub enum CryptoStoreError { OlmSessionError(#[from] OlmSessionError), #[error("URL can't be parsed")] UrlParse(#[from] ParseError), + #[error("error serializing data for the database")] + Serialization(#[from] SerdeError), + #[error("can't load session timestamps")] + SessionTimestampError, + #[error("can't save/load sessions or group sessions in the store before a account is stored")] + AccountUnset, // TODO flatten the SqlxError to make it easier for other store // implementations. #[cfg(feature = "sqlite-cryptostore")] #[error("database error")] - DatabaseError(#[from] SqlxError), + DatabaseError(#[from] SqlxError), } pub type Result = std::result::Result; @@ -56,4 +63,6 @@ pub type Result = std::result::Result; pub trait CryptoStore: Debug + Send + Sync { async fn load_account(&mut self) -> Result>; async fn save_account(&mut self, account: Arc>) -> Result<()>; + async fn save_session(&mut self, session: Arc>) -> Result<()>; + async fn load_sessions(&mut self) -> Result>; } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 34140689..c01c6d85 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -15,19 +15,22 @@ use std::path::{Path, PathBuf}; use std::result::Result as StdResult; use std::sync::Arc; +use std::time::{Duration, Instant}; use url::Url; use async_trait::async_trait; use olm_rs::PicklingMode; +use serde_json; use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection}; use tokio::sync::Mutex; use zeroize::Zeroizing; -use super::{Account, CryptoStore, Result, Session}; +use super::{Account, CryptoStore, CryptoStoreError, Result, Session}; pub struct SqliteStore { user_id: Arc, device_id: Arc, + account_id: Option, path: PathBuf, connection: Arc>, pickle_passphrase: Option>, @@ -71,6 +74,7 @@ impl SqliteStore { let store = SqliteStore { user_id: Arc::new(user_id.to_owned()), device_id: Arc::new(device_id.to_owned()), + account_id: None, path: path.as_ref().to_owned(), connection: Arc::new(Mutex::new(connection)), pickle_passphrase: passphrase, @@ -84,7 +88,7 @@ impl SqliteStore { connection .execute( r#" - CREATE TABLE IF NOT EXISTS account ( + CREATE TABLE IF NOT EXISTS accounts ( "id" INTEGER NOT NULL PRIMARY KEY, "user_id" TEXT NOT NULL, "device_id" TEXT NOT NULL, @@ -96,6 +100,25 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS sessions ( + "session_id" TEXT NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "creation_time" TEXT NOT NULL, + "last_use_time" TEXT NOT NULL, + "sender_key" TEXT NOT NULL, + "pickle" BLOB NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + ); + + CREATE INDEX "olmsessions_account_id" ON "sessions" ("account_id"); + "#, + ) + .await?; + Ok(()) } @@ -114,8 +137,8 @@ impl CryptoStore for SqliteStore { async fn load_account(&mut self) -> Result> { let mut connection = self.connection.lock().await; - let row: Option<(String, bool)> = query_as( - "SELECT pickle, shared FROM account + let row: Option<(i64, String, bool)> = query_as( + "SELECT id, pickle, shared FROM accounts WHERE user_id = ? and device_id = ?", ) .bind(&*self.user_id) @@ -124,11 +147,14 @@ impl CryptoStore for SqliteStore { .await?; let result = match row { - Some((pickle, shared)) => Some(Account::from_pickle( - pickle, - self.get_pickle_mode(), - shared, - )?), + Some((id, pickle, shared)) => { + self.account_id = Some(id); + Some(Account::from_pickle( + pickle, + self.get_pickle_mode(), + shared, + )?) + } None => None, }; @@ -141,7 +167,7 @@ impl CryptoStore for SqliteStore { let mut connection = self.connection.lock().await; query( - "INSERT OR IGNORE INTO account ( + "INSERT OR IGNORE INTO accounts ( user_id, device_id, pickle, shared ) VALUES (?, ?, ?, ?)", ) @@ -153,7 +179,7 @@ impl CryptoStore for SqliteStore { .await?; query( - "UPDATE account + "UPDATE accounts SET pickle = ?, shared = ? WHERE user_id = ? and @@ -166,8 +192,82 @@ impl CryptoStore for SqliteStore { .execute(&mut *connection) .await?; + let account_id: (i64,) = + query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?") + .bind(&*self.user_id) + .bind(&*self.device_id) + .fetch_one(&mut *connection) + .await?; + + self.account_id = Some(account_id.0); + Ok(()) } + + async fn save_session(&mut self, session: Arc>) -> Result<()> { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + + let session = session.lock().await; + + let session_id = session.session_id(); + let creation_time = serde_json::to_string(&session.creation_time.elapsed())?; + let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?; + let pickle = session.pickle(self.get_pickle_mode()); + + let mut connection = self.connection.lock().await; + + query( + "REPLACE INTO sessions ( + session_id, account_id, creation_time, last_use_time, sender_key, pickle + ) VALUES (?, ?, ?, ?, ?, ?)", + ) + .bind(&session_id) + .bind(&account_id) + .bind(&creation_time) + .bind(&last_use_time) + .bind(&session.sender_key) + .bind(&pickle) + .execute(&mut *connection) + .await?; + + Ok(()) + } + + async fn load_sessions(&mut self) -> Result> { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let rows: Vec<(String, String, String, String)> = query_as( + "SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ?" + ) + .bind(account_id) + .fetch_all(&mut *connection) + .await?; + + let now = Instant::now(); + + Ok(rows + .iter() + .map(|row| { + let pickle = &row.0; + let sender_key = &row.1; + let creation_time = now + .checked_sub(serde_json::from_str::(&row.2)?) + .ok_or(CryptoStoreError::SessionTimestampError)?; + let last_use_time = now + .checked_sub(serde_json::from_str::(&row.3)?) + .ok_or(CryptoStoreError::SessionTimestampError)?; + + Ok(Session::from_pickle( + pickle.to_string(), + self.get_pickle_mode(), + sender_key.to_string(), + creation_time, + last_use_time, + )?) + }) + .collect::>>()?) + } } impl std::fmt::Debug for SqliteStore { @@ -186,7 +286,7 @@ mod test { use tempfile::tempdir; use tokio::sync::Mutex; - use super::{Account, CryptoStore, SqliteStore}; + use super::{Account, CryptoStore, Session, SqliteStore}; static USER_ID: &str = "@example:localhost"; static DEVICE_ID: &str = "DEVICEID"; @@ -204,6 +304,28 @@ mod test { Arc::new(Mutex::new(account)) } + fn get_account_and_session() -> (Arc>, Arc>) { + let alice = Account::new(); + + let bob = Account::new(); + + bob.generate_one_time_keys(1); + let one_time_key = bob + .one_time_keys() + .curve25519() + .iter() + .nth(0) + .unwrap() + .1 + .to_owned(); + let sender_key = bob.identity_keys().curve25519().to_owned(); + let session = alice + .create_outbound_session(&sender_key, &one_time_key) + .unwrap(); + + (Arc::new(Mutex::new(alice)), Arc::new(Mutex::new(session))) + } + #[tokio::test] async fn create_store() { let tmpdir = tempdir().unwrap(); @@ -264,4 +386,35 @@ mod test { assert_eq!(*acc, loaded_account); } + + #[tokio::test] + async fn save_session() { + let mut store = get_store().await; + let (account, session) = get_account_and_session(); + + assert!(store.save_session(session.clone()).await.is_err()); + + store + .save_account(account.clone()) + .await + .expect("Can't save account"); + + store.save_session(session).await.unwrap(); + } + + #[tokio::test] + async fn load_sessions() { + let mut store = get_store().await; + let (account, session) = get_account_and_session(); + store + .save_account(account.clone()) + .await + .expect("Can't save account"); + store.save_session(session.clone()).await.unwrap(); + + let sess = session.lock().await; + + let sessions = store.load_sessions().await.expect("Can't load sessions"); + assert!(sessions.contains(&sess)); + } }