crytpo: Implement session storing/loading for the sql store.
parent
7595cab178
commit
fca8062da0
|
@ -90,19 +90,38 @@ impl Account {
|
||||||
self.inner.sign(string)
|
self.inner.sign(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pickle(&self, pickling_mode: PicklingMode) -> String {
|
pub fn pickle(&self, pickle_mode: PicklingMode) -> String {
|
||||||
self.inner.pickle(pickling_mode)
|
self.inner.pickle(pickle_mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_pickle(
|
pub fn from_pickle(
|
||||||
pickle: String,
|
pickle: String,
|
||||||
pickling_mode: PicklingMode,
|
pickle_mode: PicklingMode,
|
||||||
shared: bool,
|
shared: bool,
|
||||||
) -> Result<Self, OlmAccountError> {
|
) -> Result<Self, OlmAccountError> {
|
||||||
let acc = OlmAccount::unpickle(pickle, pickling_mode)?;
|
let acc = OlmAccount::unpickle(pickle, pickle_mode)?;
|
||||||
Ok(Account { inner: acc, shared })
|
Ok(Account { inner: acc, shared })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn create_outbound_session(
|
||||||
|
&self,
|
||||||
|
their_identity_key: &str,
|
||||||
|
their_one_time_key: &str,
|
||||||
|
) -> Result<Session, OlmSessionError> {
|
||||||
|
let session = self
|
||||||
|
.inner
|
||||||
|
.create_outbound_session(their_identity_key, their_one_time_key)?;
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
Ok(Session {
|
||||||
|
inner: session,
|
||||||
|
sender_key: their_identity_key.to_owned(),
|
||||||
|
creation_time: now.clone(),
|
||||||
|
last_use_time: now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn create_inbound_session_from(
|
pub fn create_inbound_session_from(
|
||||||
&self,
|
&self,
|
||||||
their_identity_key: &str,
|
their_identity_key: &str,
|
||||||
|
@ -133,8 +152,8 @@ impl PartialEq for Account {
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
inner: OlmSession,
|
inner: OlmSession,
|
||||||
pub(crate) sender_key: String,
|
pub(crate) sender_key: String,
|
||||||
creation_time: Instant,
|
pub(crate) creation_time: Instant,
|
||||||
last_use_time: Instant,
|
pub(crate) last_use_time: Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
|
@ -152,6 +171,36 @@ impl Session {
|
||||||
self.inner
|
self.inner
|
||||||
.matches_inbound_session_from(their_identity_key, message)
|
.matches_inbound_session_from(their_identity_key, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn session_id(&self) -> String {
|
||||||
|
self.inner.session_id()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pickle(&self, pickle_mode: PicklingMode) -> String {
|
||||||
|
self.inner.pickle(pickle_mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_pickle(
|
||||||
|
pickle: String,
|
||||||
|
pickle_mode: PicklingMode,
|
||||||
|
sender_key: String,
|
||||||
|
creation_time: Instant,
|
||||||
|
last_use_time: Instant,
|
||||||
|
) -> Result<Self, OlmSessionError> {
|
||||||
|
let session = OlmSession::unpickle(pickle, pickle_mode)?;
|
||||||
|
Ok(Session {
|
||||||
|
inner: session,
|
||||||
|
sender_key,
|
||||||
|
creation_time,
|
||||||
|
last_use_time,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for Session {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.session_id() == other.session_id()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
|
@ -20,6 +20,7 @@ use std::sync::Arc;
|
||||||
use url::ParseError;
|
use url::ParseError;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Error as SerdeError;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
@ -31,7 +32,7 @@ use olm_rs::PicklingMode;
|
||||||
pub mod sqlite;
|
pub mod sqlite;
|
||||||
|
|
||||||
#[cfg(feature = "sqlite-cryptostore")]
|
#[cfg(feature = "sqlite-cryptostore")]
|
||||||
use sqlx::{sqlite::Sqlite, Error as SqlxError};
|
use sqlx::Error as SqlxError;
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum CryptoStoreError {
|
pub enum CryptoStoreError {
|
||||||
|
@ -43,11 +44,17 @@ pub enum CryptoStoreError {
|
||||||
OlmSessionError(#[from] OlmSessionError),
|
OlmSessionError(#[from] OlmSessionError),
|
||||||
#[error("URL can't be parsed")]
|
#[error("URL can't be parsed")]
|
||||||
UrlParse(#[from] ParseError),
|
UrlParse(#[from] ParseError),
|
||||||
|
#[error("error serializing data for the database")]
|
||||||
|
Serialization(#[from] SerdeError),
|
||||||
|
#[error("can't load session timestamps")]
|
||||||
|
SessionTimestampError,
|
||||||
|
#[error("can't save/load sessions or group sessions in the store before a account is stored")]
|
||||||
|
AccountUnset,
|
||||||
// TODO flatten the SqlxError to make it easier for other store
|
// TODO flatten the SqlxError to make it easier for other store
|
||||||
// implementations.
|
// implementations.
|
||||||
#[cfg(feature = "sqlite-cryptostore")]
|
#[cfg(feature = "sqlite-cryptostore")]
|
||||||
#[error("database error")]
|
#[error("database error")]
|
||||||
DatabaseError(#[from] SqlxError<Sqlite>),
|
DatabaseError(#[from] SqlxError),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||||
|
@ -56,4 +63,6 @@ pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||||
pub trait CryptoStore: Debug + Send + Sync {
|
pub trait CryptoStore: Debug + Send + Sync {
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
||||||
|
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
|
||||||
|
async fn load_sessions(&mut self) -> Result<Vec<Session>>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,19 +15,22 @@
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::result::Result as StdResult;
|
use std::result::Result as StdResult;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use olm_rs::PicklingMode;
|
use olm_rs::PicklingMode;
|
||||||
|
use serde_json;
|
||||||
use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection};
|
use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use zeroize::Zeroizing;
|
use zeroize::Zeroizing;
|
||||||
|
|
||||||
use super::{Account, CryptoStore, Result, Session};
|
use super::{Account, CryptoStore, CryptoStoreError, Result, Session};
|
||||||
|
|
||||||
pub struct SqliteStore {
|
pub struct SqliteStore {
|
||||||
user_id: Arc<String>,
|
user_id: Arc<String>,
|
||||||
device_id: Arc<String>,
|
device_id: Arc<String>,
|
||||||
|
account_id: Option<i64>,
|
||||||
path: PathBuf,
|
path: PathBuf,
|
||||||
connection: Arc<Mutex<SqliteConnection>>,
|
connection: Arc<Mutex<SqliteConnection>>,
|
||||||
pickle_passphrase: Option<Zeroizing<String>>,
|
pickle_passphrase: Option<Zeroizing<String>>,
|
||||||
|
@ -71,6 +74,7 @@ impl SqliteStore {
|
||||||
let store = SqliteStore {
|
let store = SqliteStore {
|
||||||
user_id: Arc::new(user_id.to_owned()),
|
user_id: Arc::new(user_id.to_owned()),
|
||||||
device_id: Arc::new(device_id.to_owned()),
|
device_id: Arc::new(device_id.to_owned()),
|
||||||
|
account_id: None,
|
||||||
path: path.as_ref().to_owned(),
|
path: path.as_ref().to_owned(),
|
||||||
connection: Arc::new(Mutex::new(connection)),
|
connection: Arc::new(Mutex::new(connection)),
|
||||||
pickle_passphrase: passphrase,
|
pickle_passphrase: passphrase,
|
||||||
|
@ -84,7 +88,7 @@ impl SqliteStore {
|
||||||
connection
|
connection
|
||||||
.execute(
|
.execute(
|
||||||
r#"
|
r#"
|
||||||
CREATE TABLE IF NOT EXISTS account (
|
CREATE TABLE IF NOT EXISTS accounts (
|
||||||
"id" INTEGER NOT NULL PRIMARY KEY,
|
"id" INTEGER NOT NULL PRIMARY KEY,
|
||||||
"user_id" TEXT NOT NULL,
|
"user_id" TEXT NOT NULL,
|
||||||
"device_id" TEXT NOT NULL,
|
"device_id" TEXT NOT NULL,
|
||||||
|
@ -96,6 +100,25 @@ impl SqliteStore {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
connection
|
||||||
|
.execute(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
"session_id" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"account_id" INTEGER NOT NULL,
|
||||||
|
"creation_time" TEXT NOT NULL,
|
||||||
|
"last_use_time" TEXT NOT NULL,
|
||||||
|
"sender_key" TEXT NOT NULL,
|
||||||
|
"pickle" BLOB NOT NULL,
|
||||||
|
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
|
||||||
|
ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX "olmsessions_account_id" ON "sessions" ("account_id");
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,8 +137,8 @@ impl CryptoStore for SqliteStore {
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>> {
|
async fn load_account(&mut self) -> Result<Option<Account>> {
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
let row: Option<(String, bool)> = query_as(
|
let row: Option<(i64, String, bool)> = query_as(
|
||||||
"SELECT pickle, shared FROM account
|
"SELECT id, pickle, shared FROM accounts
|
||||||
WHERE user_id = ? and device_id = ?",
|
WHERE user_id = ? and device_id = ?",
|
||||||
)
|
)
|
||||||
.bind(&*self.user_id)
|
.bind(&*self.user_id)
|
||||||
|
@ -124,11 +147,14 @@ impl CryptoStore for SqliteStore {
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let result = match row {
|
let result = match row {
|
||||||
Some((pickle, shared)) => Some(Account::from_pickle(
|
Some((id, pickle, shared)) => {
|
||||||
|
self.account_id = Some(id);
|
||||||
|
Some(Account::from_pickle(
|
||||||
pickle,
|
pickle,
|
||||||
self.get_pickle_mode(),
|
self.get_pickle_mode(),
|
||||||
shared,
|
shared,
|
||||||
)?),
|
)?)
|
||||||
|
}
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -141,7 +167,7 @@ impl CryptoStore for SqliteStore {
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
query(
|
query(
|
||||||
"INSERT OR IGNORE INTO account (
|
"INSERT OR IGNORE INTO accounts (
|
||||||
user_id, device_id, pickle, shared
|
user_id, device_id, pickle, shared
|
||||||
) VALUES (?, ?, ?, ?)",
|
) VALUES (?, ?, ?, ?)",
|
||||||
)
|
)
|
||||||
|
@ -153,7 +179,7 @@ impl CryptoStore for SqliteStore {
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
query(
|
query(
|
||||||
"UPDATE account
|
"UPDATE accounts
|
||||||
SET pickle = ?,
|
SET pickle = ?,
|
||||||
shared = ?
|
shared = ?
|
||||||
WHERE user_id = ? and
|
WHERE user_id = ? and
|
||||||
|
@ -166,8 +192,82 @@ impl CryptoStore for SqliteStore {
|
||||||
.execute(&mut *connection)
|
.execute(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let account_id: (i64,) =
|
||||||
|
query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?")
|
||||||
|
.bind(&*self.user_id)
|
||||||
|
.bind(&*self.device_id)
|
||||||
|
.fetch_one(&mut *connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.account_id = Some(account_id.0);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()> {
|
||||||
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
|
||||||
|
let session = session.lock().await;
|
||||||
|
|
||||||
|
let session_id = session.session_id();
|
||||||
|
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
|
||||||
|
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
|
||||||
|
let pickle = session.pickle(self.get_pickle_mode());
|
||||||
|
|
||||||
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
query(
|
||||||
|
"REPLACE INTO sessions (
|
||||||
|
session_id, account_id, creation_time, last_use_time, sender_key, pickle
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
)
|
||||||
|
.bind(&session_id)
|
||||||
|
.bind(&account_id)
|
||||||
|
.bind(&creation_time)
|
||||||
|
.bind(&last_use_time)
|
||||||
|
.bind(&session.sender_key)
|
||||||
|
.bind(&pickle)
|
||||||
|
.execute(&mut *connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_sessions(&mut self) -> Result<Vec<Session>> {
|
||||||
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
let rows: Vec<(String, String, String, String)> = query_as(
|
||||||
|
"SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ?"
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.fetch_all(&mut *connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
Ok(rows
|
||||||
|
.iter()
|
||||||
|
.map(|row| {
|
||||||
|
let pickle = &row.0;
|
||||||
|
let sender_key = &row.1;
|
||||||
|
let creation_time = now
|
||||||
|
.checked_sub(serde_json::from_str::<Duration>(&row.2)?)
|
||||||
|
.ok_or(CryptoStoreError::SessionTimestampError)?;
|
||||||
|
let last_use_time = now
|
||||||
|
.checked_sub(serde_json::from_str::<Duration>(&row.3)?)
|
||||||
|
.ok_or(CryptoStoreError::SessionTimestampError)?;
|
||||||
|
|
||||||
|
Ok(Session::from_pickle(
|
||||||
|
pickle.to_string(),
|
||||||
|
self.get_pickle_mode(),
|
||||||
|
sender_key.to_string(),
|
||||||
|
creation_time,
|
||||||
|
last_use_time,
|
||||||
|
)?)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<Session>>>()?)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for SqliteStore {
|
impl std::fmt::Debug for SqliteStore {
|
||||||
|
@ -186,7 +286,7 @@ mod test {
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
use super::{Account, CryptoStore, SqliteStore};
|
use super::{Account, CryptoStore, Session, SqliteStore};
|
||||||
|
|
||||||
static USER_ID: &str = "@example:localhost";
|
static USER_ID: &str = "@example:localhost";
|
||||||
static DEVICE_ID: &str = "DEVICEID";
|
static DEVICE_ID: &str = "DEVICEID";
|
||||||
|
@ -204,6 +304,28 @@ mod test {
|
||||||
Arc::new(Mutex::new(account))
|
Arc::new(Mutex::new(account))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_account_and_session() -> (Arc<Mutex<Account>>, Arc<Mutex<Session>>) {
|
||||||
|
let alice = Account::new();
|
||||||
|
|
||||||
|
let bob = Account::new();
|
||||||
|
|
||||||
|
bob.generate_one_time_keys(1);
|
||||||
|
let one_time_key = bob
|
||||||
|
.one_time_keys()
|
||||||
|
.curve25519()
|
||||||
|
.iter()
|
||||||
|
.nth(0)
|
||||||
|
.unwrap()
|
||||||
|
.1
|
||||||
|
.to_owned();
|
||||||
|
let sender_key = bob.identity_keys().curve25519().to_owned();
|
||||||
|
let session = alice
|
||||||
|
.create_outbound_session(&sender_key, &one_time_key)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
(Arc::new(Mutex::new(alice)), Arc::new(Mutex::new(session)))
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn create_store() {
|
async fn create_store() {
|
||||||
let tmpdir = tempdir().unwrap();
|
let tmpdir = tempdir().unwrap();
|
||||||
|
@ -264,4 +386,35 @@ mod test {
|
||||||
|
|
||||||
assert_eq!(*acc, loaded_account);
|
assert_eq!(*acc, loaded_account);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn save_session() {
|
||||||
|
let mut store = get_store().await;
|
||||||
|
let (account, session) = get_account_and_session();
|
||||||
|
|
||||||
|
assert!(store.save_session(session.clone()).await.is_err());
|
||||||
|
|
||||||
|
store
|
||||||
|
.save_account(account.clone())
|
||||||
|
.await
|
||||||
|
.expect("Can't save account");
|
||||||
|
|
||||||
|
store.save_session(session).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn load_sessions() {
|
||||||
|
let mut store = get_store().await;
|
||||||
|
let (account, session) = get_account_and_session();
|
||||||
|
store
|
||||||
|
.save_account(account.clone())
|
||||||
|
.await
|
||||||
|
.expect("Can't save account");
|
||||||
|
store.save_session(session.clone()).await.unwrap();
|
||||||
|
|
||||||
|
let sess = session.lock().await;
|
||||||
|
|
||||||
|
let sessions = store.load_sessions().await.expect("Can't load sessions");
|
||||||
|
assert!(sessions.contains(&sess));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue