crytpo: Implement session storing/loading for the sql store.
parent
7595cab178
commit
fca8062da0
|
@ -90,19 +90,38 @@ impl Account {
|
|||
self.inner.sign(string)
|
||||
}
|
||||
|
||||
pub fn pickle(&self, pickling_mode: PicklingMode) -> String {
|
||||
self.inner.pickle(pickling_mode)
|
||||
pub fn pickle(&self, pickle_mode: PicklingMode) -> String {
|
||||
self.inner.pickle(pickle_mode)
|
||||
}
|
||||
|
||||
pub fn from_pickle(
|
||||
pickle: String,
|
||||
pickling_mode: PicklingMode,
|
||||
pickle_mode: PicklingMode,
|
||||
shared: bool,
|
||||
) -> Result<Self, OlmAccountError> {
|
||||
let acc = OlmAccount::unpickle(pickle, pickling_mode)?;
|
||||
let acc = OlmAccount::unpickle(pickle, pickle_mode)?;
|
||||
Ok(Account { inner: acc, shared })
|
||||
}
|
||||
|
||||
pub fn create_outbound_session(
|
||||
&self,
|
||||
their_identity_key: &str,
|
||||
their_one_time_key: &str,
|
||||
) -> Result<Session, OlmSessionError> {
|
||||
let session = self
|
||||
.inner
|
||||
.create_outbound_session(their_identity_key, their_one_time_key)?;
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
Ok(Session {
|
||||
inner: session,
|
||||
sender_key: their_identity_key.to_owned(),
|
||||
creation_time: now.clone(),
|
||||
last_use_time: now,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_inbound_session_from(
|
||||
&self,
|
||||
their_identity_key: &str,
|
||||
|
@ -133,8 +152,8 @@ impl PartialEq for Account {
|
|||
pub struct Session {
|
||||
inner: OlmSession,
|
||||
pub(crate) sender_key: String,
|
||||
creation_time: Instant,
|
||||
last_use_time: Instant,
|
||||
pub(crate) creation_time: Instant,
|
||||
pub(crate) last_use_time: Instant,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
|
@ -152,6 +171,36 @@ impl Session {
|
|||
self.inner
|
||||
.matches_inbound_session_from(their_identity_key, message)
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> String {
|
||||
self.inner.session_id()
|
||||
}
|
||||
|
||||
pub fn pickle(&self, pickle_mode: PicklingMode) -> String {
|
||||
self.inner.pickle(pickle_mode)
|
||||
}
|
||||
|
||||
pub fn from_pickle(
|
||||
pickle: String,
|
||||
pickle_mode: PicklingMode,
|
||||
sender_key: String,
|
||||
creation_time: Instant,
|
||||
last_use_time: Instant,
|
||||
) -> Result<Self, OlmSessionError> {
|
||||
let session = OlmSession::unpickle(pickle, pickle_mode)?;
|
||||
Ok(Session {
|
||||
inner: session,
|
||||
sender_key,
|
||||
creation_time,
|
||||
last_use_time,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Session {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.session_id() == other.session_id()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
|
@ -20,6 +20,7 @@ use std::sync::Arc;
|
|||
use url::ParseError;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Error as SerdeError;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
|
@ -31,7 +32,7 @@ use olm_rs::PicklingMode;
|
|||
pub mod sqlite;
|
||||
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
use sqlx::{sqlite::Sqlite, Error as SqlxError};
|
||||
use sqlx::Error as SqlxError;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum CryptoStoreError {
|
||||
|
@ -43,11 +44,17 @@ pub enum CryptoStoreError {
|
|||
OlmSessionError(#[from] OlmSessionError),
|
||||
#[error("URL can't be parsed")]
|
||||
UrlParse(#[from] ParseError),
|
||||
#[error("error serializing data for the database")]
|
||||
Serialization(#[from] SerdeError),
|
||||
#[error("can't load session timestamps")]
|
||||
SessionTimestampError,
|
||||
#[error("can't save/load sessions or group sessions in the store before a account is stored")]
|
||||
AccountUnset,
|
||||
// TODO flatten the SqlxError to make it easier for other store
|
||||
// implementations.
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
#[error("database error")]
|
||||
DatabaseError(#[from] SqlxError<Sqlite>),
|
||||
DatabaseError(#[from] SqlxError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||
|
@ -56,4 +63,6 @@ pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
|||
pub trait CryptoStore: Debug + Send + Sync {
|
||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
||||
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
|
||||
async fn load_sessions(&mut self) -> Result<Vec<Session>>;
|
||||
}
|
||||
|
|
|
@ -15,19 +15,22 @@
|
|||
use std::path::{Path, PathBuf};
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use url::Url;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use olm_rs::PicklingMode;
|
||||
use serde_json;
|
||||
use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection};
|
||||
use tokio::sync::Mutex;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::{Account, CryptoStore, Result, Session};
|
||||
use super::{Account, CryptoStore, CryptoStoreError, Result, Session};
|
||||
|
||||
pub struct SqliteStore {
|
||||
user_id: Arc<String>,
|
||||
device_id: Arc<String>,
|
||||
account_id: Option<i64>,
|
||||
path: PathBuf,
|
||||
connection: Arc<Mutex<SqliteConnection>>,
|
||||
pickle_passphrase: Option<Zeroizing<String>>,
|
||||
|
@ -71,6 +74,7 @@ impl SqliteStore {
|
|||
let store = SqliteStore {
|
||||
user_id: Arc::new(user_id.to_owned()),
|
||||
device_id: Arc::new(device_id.to_owned()),
|
||||
account_id: None,
|
||||
path: path.as_ref().to_owned(),
|
||||
connection: Arc::new(Mutex::new(connection)),
|
||||
pickle_passphrase: passphrase,
|
||||
|
@ -84,7 +88,7 @@ impl SqliteStore {
|
|||
connection
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS account (
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY,
|
||||
"user_id" TEXT NOT NULL,
|
||||
"device_id" TEXT NOT NULL,
|
||||
|
@ -96,6 +100,25 @@ impl SqliteStore {
|
|||
)
|
||||
.await?;
|
||||
|
||||
connection
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
"session_id" TEXT NOT NULL PRIMARY KEY,
|
||||
"account_id" INTEGER NOT NULL,
|
||||
"creation_time" TEXT NOT NULL,
|
||||
"last_use_time" TEXT NOT NULL,
|
||||
"sender_key" TEXT NOT NULL,
|
||||
"pickle" BLOB NOT NULL,
|
||||
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
|
||||
ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX "olmsessions_account_id" ON "sessions" ("account_id");
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -114,8 +137,8 @@ impl CryptoStore for SqliteStore {
|
|||
async fn load_account(&mut self) -> Result<Option<Account>> {
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
let row: Option<(String, bool)> = query_as(
|
||||
"SELECT pickle, shared FROM account
|
||||
let row: Option<(i64, String, bool)> = query_as(
|
||||
"SELECT id, pickle, shared FROM accounts
|
||||
WHERE user_id = ? and device_id = ?",
|
||||
)
|
||||
.bind(&*self.user_id)
|
||||
|
@ -124,11 +147,14 @@ impl CryptoStore for SqliteStore {
|
|||
.await?;
|
||||
|
||||
let result = match row {
|
||||
Some((pickle, shared)) => Some(Account::from_pickle(
|
||||
Some((id, pickle, shared)) => {
|
||||
self.account_id = Some(id);
|
||||
Some(Account::from_pickle(
|
||||
pickle,
|
||||
self.get_pickle_mode(),
|
||||
shared,
|
||||
)?),
|
||||
)?)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
|
@ -141,7 +167,7 @@ impl CryptoStore for SqliteStore {
|
|||
let mut connection = self.connection.lock().await;
|
||||
|
||||
query(
|
||||
"INSERT OR IGNORE INTO account (
|
||||
"INSERT OR IGNORE INTO accounts (
|
||||
user_id, device_id, pickle, shared
|
||||
) VALUES (?, ?, ?, ?)",
|
||||
)
|
||||
|
@ -153,7 +179,7 @@ impl CryptoStore for SqliteStore {
|
|||
.await?;
|
||||
|
||||
query(
|
||||
"UPDATE account
|
||||
"UPDATE accounts
|
||||
SET pickle = ?,
|
||||
shared = ?
|
||||
WHERE user_id = ? and
|
||||
|
@ -166,8 +192,82 @@ impl CryptoStore for SqliteStore {
|
|||
.execute(&mut *connection)
|
||||
.await?;
|
||||
|
||||
let account_id: (i64,) =
|
||||
query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?")
|
||||
.bind(&*self.user_id)
|
||||
.bind(&*self.device_id)
|
||||
.fetch_one(&mut *connection)
|
||||
.await?;
|
||||
|
||||
self.account_id = Some(account_id.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()> {
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
|
||||
let session = session.lock().await;
|
||||
|
||||
let session_id = session.session_id();
|
||||
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
|
||||
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
|
||||
let pickle = session.pickle(self.get_pickle_mode());
|
||||
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
query(
|
||||
"REPLACE INTO sessions (
|
||||
session_id, account_id, creation_time, last_use_time, sender_key, pickle
|
||||
) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
)
|
||||
.bind(&session_id)
|
||||
.bind(&account_id)
|
||||
.bind(&creation_time)
|
||||
.bind(&last_use_time)
|
||||
.bind(&session.sender_key)
|
||||
.bind(&pickle)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_sessions(&mut self) -> Result<Vec<Session>> {
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
let rows: Vec<(String, String, String, String)> = query_as(
|
||||
"SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ?"
|
||||
)
|
||||
.bind(account_id)
|
||||
.fetch_all(&mut *connection)
|
||||
.await?;
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
Ok(rows
|
||||
.iter()
|
||||
.map(|row| {
|
||||
let pickle = &row.0;
|
||||
let sender_key = &row.1;
|
||||
let creation_time = now
|
||||
.checked_sub(serde_json::from_str::<Duration>(&row.2)?)
|
||||
.ok_or(CryptoStoreError::SessionTimestampError)?;
|
||||
let last_use_time = now
|
||||
.checked_sub(serde_json::from_str::<Duration>(&row.3)?)
|
||||
.ok_or(CryptoStoreError::SessionTimestampError)?;
|
||||
|
||||
Ok(Session::from_pickle(
|
||||
pickle.to_string(),
|
||||
self.get_pickle_mode(),
|
||||
sender_key.to_string(),
|
||||
creation_time,
|
||||
last_use_time,
|
||||
)?)
|
||||
})
|
||||
.collect::<Result<Vec<Session>>>()?)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SqliteStore {
|
||||
|
@ -186,7 +286,7 @@ mod test {
|
|||
use tempfile::tempdir;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::{Account, CryptoStore, SqliteStore};
|
||||
use super::{Account, CryptoStore, Session, SqliteStore};
|
||||
|
||||
static USER_ID: &str = "@example:localhost";
|
||||
static DEVICE_ID: &str = "DEVICEID";
|
||||
|
@ -204,6 +304,28 @@ mod test {
|
|||
Arc::new(Mutex::new(account))
|
||||
}
|
||||
|
||||
fn get_account_and_session() -> (Arc<Mutex<Account>>, Arc<Mutex<Session>>) {
|
||||
let alice = Account::new();
|
||||
|
||||
let bob = Account::new();
|
||||
|
||||
bob.generate_one_time_keys(1);
|
||||
let one_time_key = bob
|
||||
.one_time_keys()
|
||||
.curve25519()
|
||||
.iter()
|
||||
.nth(0)
|
||||
.unwrap()
|
||||
.1
|
||||
.to_owned();
|
||||
let sender_key = bob.identity_keys().curve25519().to_owned();
|
||||
let session = alice
|
||||
.create_outbound_session(&sender_key, &one_time_key)
|
||||
.unwrap();
|
||||
|
||||
(Arc::new(Mutex::new(alice)), Arc::new(Mutex::new(session)))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_store() {
|
||||
let tmpdir = tempdir().unwrap();
|
||||
|
@ -264,4 +386,35 @@ mod test {
|
|||
|
||||
assert_eq!(*acc, loaded_account);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_session() {
|
||||
let mut store = get_store().await;
|
||||
let (account, session) = get_account_and_session();
|
||||
|
||||
assert!(store.save_session(session.clone()).await.is_err());
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
|
||||
store.save_session(session).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_sessions() {
|
||||
let mut store = get_store().await;
|
||||
let (account, session) = get_account_and_session();
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_session(session.clone()).await.unwrap();
|
||||
|
||||
let sess = session.lock().await;
|
||||
|
||||
let sessions = store.load_sessions().await.expect("Can't load sessions");
|
||||
assert!(sessions.contains(&sess));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue