crypto: Implement group session loading for the sqlite store.
This commit is contained in:
parent
b5b4542cd5
commit
559a5847bb
5 changed files with 153 additions and 36 deletions
|
@ -31,7 +31,7 @@ impl SessionStore {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn add(&mut self, session: Session) {
|
||||
pub async fn add(&mut self, session: Session) -> Arc<Mutex<Session>> {
|
||||
if !self.entries.contains_key(&session.sender_key) {
|
||||
self.entries.insert(
|
||||
session.sender_key.to_owned(),
|
||||
|
@ -39,7 +39,10 @@ impl SessionStore {
|
|||
);
|
||||
}
|
||||
let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
|
||||
sessions.lock().await.push(Arc::new(Mutex::new(session)));
|
||||
let session = Arc::new(Mutex::new(session));
|
||||
sessions.lock().await.push(session.clone());
|
||||
|
||||
session
|
||||
}
|
||||
|
||||
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> {
|
||||
|
|
|
@ -228,6 +228,23 @@ impl InboundGroupSession {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn from_pickle(
|
||||
pickle: String,
|
||||
pickle_mode: PicklingMode,
|
||||
sender_key: String,
|
||||
signing_key: String,
|
||||
room_id: String,
|
||||
) -> Result<Self, OlmGroupSessionError> {
|
||||
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
|
||||
Ok(InboundGroupSession {
|
||||
inner: session,
|
||||
sender_key,
|
||||
signing_key,
|
||||
room_id,
|
||||
forwarding_chains: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> String {
|
||||
self.inner.session_id()
|
||||
}
|
||||
|
|
|
@ -61,9 +61,8 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(self.sessions.get(sender_key))
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
|
||||
self.inbound_group_sessions.add(session);
|
||||
Ok(())
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
|
||||
Ok(self.inbound_group_sessions.add(session))
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
|
|
|
@ -25,7 +25,7 @@ use thiserror::Error;
|
|||
use tokio::sync::Mutex;
|
||||
|
||||
use super::olm::{Account, InboundGroupSession, Session};
|
||||
use olm_rs::errors::{OlmAccountError, OlmSessionError};
|
||||
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
||||
use olm_rs::PicklingMode;
|
||||
|
||||
pub mod memorystore;
|
||||
|
@ -40,9 +40,11 @@ pub enum CryptoStoreError {
|
|||
#[error("can't read or write from the store")]
|
||||
Io(#[from] IoError),
|
||||
#[error("can't finish Olm Account operation {0}")]
|
||||
OlmAccountError(#[from] OlmAccountError),
|
||||
OlmAccount(#[from] OlmAccountError),
|
||||
#[error("can't finish Olm Session operation {0}")]
|
||||
OlmSessionError(#[from] OlmSessionError),
|
||||
OlmSession(#[from] OlmSessionError),
|
||||
#[error("can't finish Olm GruoupSession operation {0}")]
|
||||
OlmGroupSession(#[from] OlmGroupSessionError),
|
||||
#[error("URL can't be parsed")]
|
||||
UrlParse(#[from] ParseError),
|
||||
#[error("error serializing data for the database")]
|
||||
|
@ -70,7 +72,7 @@ pub trait CryptoStore: Debug + Send + Sync {
|
|||
&mut self,
|
||||
sender_key: &str,
|
||||
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()>;
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
|
||||
async fn get_inbound_group_session(
|
||||
&mut self,
|
||||
room_id: &str,
|
||||
|
|
|
@ -201,32 +201,35 @@ impl SqliteStore {
|
|||
.collect::<Result<Vec<Arc<Mutex<Session>>>>>()?)
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
|
||||
async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let pickle = session.pickle(self.get_pickle_mode());
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
query(
|
||||
"INSERT INTO inbound_group_sessions (
|
||||
session_id, account_id, sender_key, signing_key,
|
||||
room_id, pickle
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
ON CONFLICT(session_id) DO UPDATE SET
|
||||
pickle = ?6
|
||||
WHERE session_id = ?1
|
||||
",
|
||||
let rows: Vec<(String, String, String, String)> = query_as(
|
||||
"SELECT pickle, sender_key, signing_key, room_id
|
||||
FROM inbound_group_sessions WHERE account_id = ?",
|
||||
)
|
||||
.bind(&session.session_id())
|
||||
.bind(account_id)
|
||||
.bind(&session.sender_key)
|
||||
.bind(&session.signing_key)
|
||||
.bind(&session.room_id)
|
||||
.bind(&pickle)
|
||||
.execute(&mut *connection)
|
||||
.fetch_all(&mut *connection)
|
||||
.await?;
|
||||
|
||||
self.inbound_group_sessions.add(session);
|
||||
Ok(())
|
||||
Ok(rows
|
||||
.iter()
|
||||
.map(|row| {
|
||||
let pickle = &row.0;
|
||||
let sender_key = &row.1;
|
||||
let signing_key = &row.2;
|
||||
let room_id = &row.3;
|
||||
|
||||
Ok(InboundGroupSession::from_pickle(
|
||||
pickle.to_string(),
|
||||
self.get_pickle_mode(),
|
||||
sender_key.to_string(),
|
||||
signing_key.to_owned(),
|
||||
room_id.to_owned(),
|
||||
)?)
|
||||
})
|
||||
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
||||
}
|
||||
|
||||
fn get_pickle_mode(&self) -> PicklingMode {
|
||||
|
@ -265,6 +268,17 @@ impl CryptoStore for SqliteStore {
|
|||
None => None,
|
||||
};
|
||||
|
||||
drop(connection);
|
||||
|
||||
let mut sessions = self.load_inbound_group_sessions().await?;
|
||||
|
||||
let _ = sessions
|
||||
.drain(..)
|
||||
.map(|s| {
|
||||
self.inbound_group_sessions.add(s);
|
||||
})
|
||||
.collect::<()>();
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
@ -333,7 +347,9 @@ impl CryptoStore for SqliteStore {
|
|||
}
|
||||
|
||||
async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
|
||||
todo!()
|
||||
let session = self.sessions.add(session).await;
|
||||
self.save_session(session).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_sessions(
|
||||
|
@ -343,8 +359,32 @@ impl CryptoStore for SqliteStore {
|
|||
Ok(self.get_sessions_for(sender_key).await?)
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
|
||||
todo!()
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let pickle = session.pickle(self.get_pickle_mode());
|
||||
let mut connection = self.connection.lock().await;
|
||||
let session_id = session.session_id();
|
||||
|
||||
query(
|
||||
"INSERT INTO inbound_group_sessions (
|
||||
session_id, account_id, sender_key, signing_key,
|
||||
room_id, pickle
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
ON CONFLICT(session_id) DO UPDATE SET
|
||||
pickle = ?6
|
||||
WHERE session_id = ?1
|
||||
",
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(account_id)
|
||||
.bind(&session.sender_key)
|
||||
.bind(&session.signing_key)
|
||||
.bind(&session.room_id)
|
||||
.bind(&pickle)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
|
||||
Ok(self.inbound_group_sessions.add(session))
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
|
@ -353,7 +393,9 @@ impl CryptoStore for SqliteStore {
|
|||
sender_key: &str,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
||||
todo!()
|
||||
Ok(self
|
||||
.inbound_group_sessions
|
||||
.get(room_id, sender_key, session_id))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -387,12 +429,23 @@ mod test {
|
|||
.expect("Can't create store")
|
||||
}
|
||||
|
||||
async fn get_loaded_store() -> (Arc<Mutex<Account>>, SqliteStore) {
|
||||
let mut store = get_store().await;
|
||||
let account = get_account();
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
|
||||
(account, store)
|
||||
}
|
||||
|
||||
fn get_account() -> Arc<Mutex<Account>> {
|
||||
let account = Account::new();
|
||||
Arc::new(Mutex::new(account))
|
||||
}
|
||||
|
||||
fn get_account_and_session() -> (Arc<Mutex<Account>>, Arc<Mutex<Session>>) {
|
||||
fn get_account_and_session() -> (Arc<Mutex<Account>>, Session) {
|
||||
let alice = Account::new();
|
||||
|
||||
let bob = Account::new();
|
||||
|
@ -411,7 +464,7 @@ mod test {
|
|||
.create_outbound_session(&sender_key, &one_time_key)
|
||||
.unwrap();
|
||||
|
||||
(Arc::new(Mutex::new(alice)), Arc::new(Mutex::new(session)))
|
||||
(Arc::new(Mutex::new(alice)), session)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -479,6 +532,7 @@ mod test {
|
|||
async fn save_session() {
|
||||
let mut store = get_store().await;
|
||||
let (account, session) = get_account_and_session();
|
||||
let session = Arc::new(Mutex::new(session));
|
||||
|
||||
assert!(store.save_session(session.clone()).await.is_err());
|
||||
|
||||
|
@ -494,6 +548,7 @@ mod test {
|
|||
async fn load_sessions() {
|
||||
let mut store = get_store().await;
|
||||
let (account, session) = get_account_and_session();
|
||||
let session = Arc::new(Mutex::new(session));
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
|
@ -512,14 +567,28 @@ mod test {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_inbound_group_session() {
|
||||
async fn add_and_save_session() {
|
||||
let mut store = get_store().await;
|
||||
let account = get_account();
|
||||
let (account, session) = get_account_and_session();
|
||||
let sender_key = session.sender_key.to_owned();
|
||||
let session_id = session.session_id();
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.add_and_save_session(session).await.unwrap();
|
||||
|
||||
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
|
||||
let sessions_lock = sessions.lock().await;
|
||||
let session = &sessions_lock[0];
|
||||
|
||||
assert_eq!(session_id, *session.lock().await.session_id());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_inbound_group_session() {
|
||||
let (account, mut store) = get_loaded_store().await;
|
||||
|
||||
let acc = account.lock().await;
|
||||
let identity_keys = acc.identity_keys();
|
||||
|
@ -537,4 +606,31 @@ mod test {
|
|||
.await
|
||||
.expect("Can't save group session");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_inbound_group_session() {
|
||||
let (account, mut store) = get_loaded_store().await;
|
||||
|
||||
let acc = account.lock().await;
|
||||
let identity_keys = acc.identity_keys();
|
||||
let outbound_session = OlmOutboundGroupSession::new();
|
||||
let session = InboundGroupSession::new(
|
||||
identity_keys.curve25519(),
|
||||
identity_keys.ed25519(),
|
||||
"!test:localhost",
|
||||
&outbound_session.session_key(),
|
||||
)
|
||||
.expect("Can't create session");
|
||||
|
||||
let session_id = session.session_id();
|
||||
|
||||
store
|
||||
.save_inbound_group_session(session)
|
||||
.await
|
||||
.expect("Can't save group session");
|
||||
|
||||
let sessions = store.load_inbound_group_sessions().await.unwrap();
|
||||
|
||||
assert_eq!(session_id, sessions[0].session_id());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue