crypto: Implement group session loading for the sqlite store.
parent
b5b4542cd5
commit
559a5847bb
|
@ -31,7 +31,7 @@ impl SessionStore {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add(&mut self, session: Session) {
|
pub async fn add(&mut self, session: Session) -> Arc<Mutex<Session>> {
|
||||||
if !self.entries.contains_key(&session.sender_key) {
|
if !self.entries.contains_key(&session.sender_key) {
|
||||||
self.entries.insert(
|
self.entries.insert(
|
||||||
session.sender_key.to_owned(),
|
session.sender_key.to_owned(),
|
||||||
|
@ -39,7 +39,10 @@ impl SessionStore {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
|
let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
|
||||||
sessions.lock().await.push(Arc::new(Mutex::new(session)));
|
let session = Arc::new(Mutex::new(session));
|
||||||
|
sessions.lock().await.push(session.clone());
|
||||||
|
|
||||||
|
session
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> {
|
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> {
|
||||||
|
|
|
@ -228,6 +228,23 @@ impl InboundGroupSession {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_pickle(
|
||||||
|
pickle: String,
|
||||||
|
pickle_mode: PicklingMode,
|
||||||
|
sender_key: String,
|
||||||
|
signing_key: String,
|
||||||
|
room_id: String,
|
||||||
|
) -> Result<Self, OlmGroupSessionError> {
|
||||||
|
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
|
||||||
|
Ok(InboundGroupSession {
|
||||||
|
inner: session,
|
||||||
|
sender_key,
|
||||||
|
signing_key,
|
||||||
|
room_id,
|
||||||
|
forwarding_chains: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn session_id(&self) -> String {
|
pub fn session_id(&self) -> String {
|
||||||
self.inner.session_id()
|
self.inner.session_id()
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,9 +61,8 @@ impl CryptoStore for MemoryStore {
|
||||||
Ok(self.sessions.get(sender_key))
|
Ok(self.sessions.get(sender_key))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
|
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
|
||||||
self.inbound_group_sessions.add(session);
|
Ok(self.inbound_group_sessions.add(session))
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
|
|
|
@ -25,7 +25,7 @@ use thiserror::Error;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
use super::olm::{Account, InboundGroupSession, Session};
|
use super::olm::{Account, InboundGroupSession, Session};
|
||||||
use olm_rs::errors::{OlmAccountError, OlmSessionError};
|
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
||||||
use olm_rs::PicklingMode;
|
use olm_rs::PicklingMode;
|
||||||
|
|
||||||
pub mod memorystore;
|
pub mod memorystore;
|
||||||
|
@ -40,9 +40,11 @@ pub enum CryptoStoreError {
|
||||||
#[error("can't read or write from the store")]
|
#[error("can't read or write from the store")]
|
||||||
Io(#[from] IoError),
|
Io(#[from] IoError),
|
||||||
#[error("can't finish Olm Account operation {0}")]
|
#[error("can't finish Olm Account operation {0}")]
|
||||||
OlmAccountError(#[from] OlmAccountError),
|
OlmAccount(#[from] OlmAccountError),
|
||||||
#[error("can't finish Olm Session operation {0}")]
|
#[error("can't finish Olm Session operation {0}")]
|
||||||
OlmSessionError(#[from] OlmSessionError),
|
OlmSession(#[from] OlmSessionError),
|
||||||
|
#[error("can't finish Olm GruoupSession operation {0}")]
|
||||||
|
OlmGroupSession(#[from] OlmGroupSessionError),
|
||||||
#[error("URL can't be parsed")]
|
#[error("URL can't be parsed")]
|
||||||
UrlParse(#[from] ParseError),
|
UrlParse(#[from] ParseError),
|
||||||
#[error("error serializing data for the database")]
|
#[error("error serializing data for the database")]
|
||||||
|
@ -70,7 +72,7 @@ pub trait CryptoStore: Debug + Send + Sync {
|
||||||
&mut self,
|
&mut self,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
|
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
|
||||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()>;
|
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
&mut self,
|
&mut self,
|
||||||
room_id: &str,
|
room_id: &str,
|
||||||
|
|
|
@ -201,32 +201,35 @@ impl SqliteStore {
|
||||||
.collect::<Result<Vec<Arc<Mutex<Session>>>>>()?)
|
.collect::<Result<Vec<Arc<Mutex<Session>>>>>()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
|
async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
|
||||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let pickle = session.pickle(self.get_pickle_mode());
|
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
query(
|
let rows: Vec<(String, String, String, String)> = query_as(
|
||||||
"INSERT INTO inbound_group_sessions (
|
"SELECT pickle, sender_key, signing_key, room_id
|
||||||
session_id, account_id, sender_key, signing_key,
|
FROM inbound_group_sessions WHERE account_id = ?",
|
||||||
room_id, pickle
|
|
||||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
|
||||||
ON CONFLICT(session_id) DO UPDATE SET
|
|
||||||
pickle = ?6
|
|
||||||
WHERE session_id = ?1
|
|
||||||
",
|
|
||||||
)
|
)
|
||||||
.bind(&session.session_id())
|
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
.bind(&session.sender_key)
|
.fetch_all(&mut *connection)
|
||||||
.bind(&session.signing_key)
|
|
||||||
.bind(&session.room_id)
|
|
||||||
.bind(&pickle)
|
|
||||||
.execute(&mut *connection)
|
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
self.inbound_group_sessions.add(session);
|
Ok(rows
|
||||||
Ok(())
|
.iter()
|
||||||
|
.map(|row| {
|
||||||
|
let pickle = &row.0;
|
||||||
|
let sender_key = &row.1;
|
||||||
|
let signing_key = &row.2;
|
||||||
|
let room_id = &row.3;
|
||||||
|
|
||||||
|
Ok(InboundGroupSession::from_pickle(
|
||||||
|
pickle.to_string(),
|
||||||
|
self.get_pickle_mode(),
|
||||||
|
sender_key.to_string(),
|
||||||
|
signing_key.to_owned(),
|
||||||
|
room_id.to_owned(),
|
||||||
|
)?)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pickle_mode(&self) -> PicklingMode {
|
fn get_pickle_mode(&self) -> PicklingMode {
|
||||||
|
@ -265,6 +268,17 @@ impl CryptoStore for SqliteStore {
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
drop(connection);
|
||||||
|
|
||||||
|
let mut sessions = self.load_inbound_group_sessions().await?;
|
||||||
|
|
||||||
|
let _ = sessions
|
||||||
|
.drain(..)
|
||||||
|
.map(|s| {
|
||||||
|
self.inbound_group_sessions.add(s);
|
||||||
|
})
|
||||||
|
.collect::<()>();
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,7 +347,9 @@ impl CryptoStore for SqliteStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
|
async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
|
||||||
todo!()
|
let session = self.sessions.add(session).await;
|
||||||
|
self.save_session(session).await?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_sessions(
|
async fn get_sessions(
|
||||||
|
@ -343,8 +359,32 @@ impl CryptoStore for SqliteStore {
|
||||||
Ok(self.get_sessions_for(sender_key).await?)
|
Ok(self.get_sessions_for(sender_key).await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
|
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
|
||||||
todo!()
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
let pickle = session.pickle(self.get_pickle_mode());
|
||||||
|
let mut connection = self.connection.lock().await;
|
||||||
|
let session_id = session.session_id();
|
||||||
|
|
||||||
|
query(
|
||||||
|
"INSERT INTO inbound_group_sessions (
|
||||||
|
session_id, account_id, sender_key, signing_key,
|
||||||
|
room_id, pickle
|
||||||
|
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||||
|
ON CONFLICT(session_id) DO UPDATE SET
|
||||||
|
pickle = ?6
|
||||||
|
WHERE session_id = ?1
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(account_id)
|
||||||
|
.bind(&session.sender_key)
|
||||||
|
.bind(&session.signing_key)
|
||||||
|
.bind(&session.room_id)
|
||||||
|
.bind(&pickle)
|
||||||
|
.execute(&mut *connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(self.inbound_group_sessions.add(session))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
|
@ -353,7 +393,9 @@ impl CryptoStore for SqliteStore {
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
||||||
todo!()
|
Ok(self
|
||||||
|
.inbound_group_sessions
|
||||||
|
.get(room_id, sender_key, session_id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -387,12 +429,23 @@ mod test {
|
||||||
.expect("Can't create store")
|
.expect("Can't create store")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_loaded_store() -> (Arc<Mutex<Account>>, SqliteStore) {
|
||||||
|
let mut store = get_store().await;
|
||||||
|
let account = get_account();
|
||||||
|
store
|
||||||
|
.save_account(account.clone())
|
||||||
|
.await
|
||||||
|
.expect("Can't save account");
|
||||||
|
|
||||||
|
(account, store)
|
||||||
|
}
|
||||||
|
|
||||||
fn get_account() -> Arc<Mutex<Account>> {
|
fn get_account() -> Arc<Mutex<Account>> {
|
||||||
let account = Account::new();
|
let account = Account::new();
|
||||||
Arc::new(Mutex::new(account))
|
Arc::new(Mutex::new(account))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_account_and_session() -> (Arc<Mutex<Account>>, Arc<Mutex<Session>>) {
|
fn get_account_and_session() -> (Arc<Mutex<Account>>, Session) {
|
||||||
let alice = Account::new();
|
let alice = Account::new();
|
||||||
|
|
||||||
let bob = Account::new();
|
let bob = Account::new();
|
||||||
|
@ -411,7 +464,7 @@ mod test {
|
||||||
.create_outbound_session(&sender_key, &one_time_key)
|
.create_outbound_session(&sender_key, &one_time_key)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
(Arc::new(Mutex::new(alice)), Arc::new(Mutex::new(session)))
|
(Arc::new(Mutex::new(alice)), session)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -479,6 +532,7 @@ mod test {
|
||||||
async fn save_session() {
|
async fn save_session() {
|
||||||
let mut store = get_store().await;
|
let mut store = get_store().await;
|
||||||
let (account, session) = get_account_and_session();
|
let (account, session) = get_account_and_session();
|
||||||
|
let session = Arc::new(Mutex::new(session));
|
||||||
|
|
||||||
assert!(store.save_session(session.clone()).await.is_err());
|
assert!(store.save_session(session.clone()).await.is_err());
|
||||||
|
|
||||||
|
@ -494,6 +548,7 @@ mod test {
|
||||||
async fn load_sessions() {
|
async fn load_sessions() {
|
||||||
let mut store = get_store().await;
|
let mut store = get_store().await;
|
||||||
let (account, session) = get_account_and_session();
|
let (account, session) = get_account_and_session();
|
||||||
|
let session = Arc::new(Mutex::new(session));
|
||||||
store
|
store
|
||||||
.save_account(account.clone())
|
.save_account(account.clone())
|
||||||
.await
|
.await
|
||||||
|
@ -512,14 +567,28 @@ mod test {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn save_inbound_group_session() {
|
async fn add_and_save_session() {
|
||||||
let mut store = get_store().await;
|
let mut store = get_store().await;
|
||||||
let account = get_account();
|
let (account, session) = get_account_and_session();
|
||||||
|
let sender_key = session.sender_key.to_owned();
|
||||||
|
let session_id = session.session_id();
|
||||||
|
|
||||||
store
|
store
|
||||||
.save_account(account.clone())
|
.save_account(account.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Can't save account");
|
.expect("Can't save account");
|
||||||
|
store.add_and_save_session(session).await.unwrap();
|
||||||
|
|
||||||
|
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
|
||||||
|
let sessions_lock = sessions.lock().await;
|
||||||
|
let session = &sessions_lock[0];
|
||||||
|
|
||||||
|
assert_eq!(session_id, *session.lock().await.session_id());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn save_inbound_group_session() {
|
||||||
|
let (account, mut store) = get_loaded_store().await;
|
||||||
|
|
||||||
let acc = account.lock().await;
|
let acc = account.lock().await;
|
||||||
let identity_keys = acc.identity_keys();
|
let identity_keys = acc.identity_keys();
|
||||||
|
@ -537,4 +606,31 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.expect("Can't save group session");
|
.expect("Can't save group session");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn load_inbound_group_session() {
|
||||||
|
let (account, mut store) = get_loaded_store().await;
|
||||||
|
|
||||||
|
let acc = account.lock().await;
|
||||||
|
let identity_keys = acc.identity_keys();
|
||||||
|
let outbound_session = OlmOutboundGroupSession::new();
|
||||||
|
let session = InboundGroupSession::new(
|
||||||
|
identity_keys.curve25519(),
|
||||||
|
identity_keys.ed25519(),
|
||||||
|
"!test:localhost",
|
||||||
|
&outbound_session.session_key(),
|
||||||
|
)
|
||||||
|
.expect("Can't create session");
|
||||||
|
|
||||||
|
let session_id = session.session_id();
|
||||||
|
|
||||||
|
store
|
||||||
|
.save_inbound_group_session(session)
|
||||||
|
.await
|
||||||
|
.expect("Can't save group session");
|
||||||
|
|
||||||
|
let sessions = store.load_inbound_group_sessions().await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(session_id, sessions[0].session_id());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue