crypto: Implement group session loading for the sqlite store.

master
Damir Jelić 2020-03-31 16:19:08 +02:00
parent b5b4542cd5
commit 559a5847bb
5 changed files with 153 additions and 36 deletions

View File

@ -31,7 +31,7 @@ impl SessionStore {
} }
} }
pub async fn add(&mut self, session: Session) { pub async fn add(&mut self, session: Session) -> Arc<Mutex<Session>> {
if !self.entries.contains_key(&session.sender_key) { if !self.entries.contains_key(&session.sender_key) {
self.entries.insert( self.entries.insert(
session.sender_key.to_owned(), session.sender_key.to_owned(),
@ -39,7 +39,10 @@ impl SessionStore {
); );
} }
let mut sessions = self.entries.get_mut(&session.sender_key).unwrap(); let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
sessions.lock().await.push(Arc::new(Mutex::new(session))); let session = Arc::new(Mutex::new(session));
sessions.lock().await.push(session.clone());
session
} }
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> { pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> {

View File

@ -228,6 +228,23 @@ impl InboundGroupSession {
}) })
} }
pub fn from_pickle(
pickle: String,
pickle_mode: PicklingMode,
sender_key: String,
signing_key: String,
room_id: String,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
Ok(InboundGroupSession {
inner: session,
sender_key,
signing_key,
room_id,
forwarding_chains: None,
})
}
pub fn session_id(&self) -> String { pub fn session_id(&self) -> String {
self.inner.session_id() self.inner.session_id()
} }

View File

@ -61,9 +61,8 @@ impl CryptoStore for MemoryStore {
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> { async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
self.inbound_group_sessions.add(session); Ok(self.inbound_group_sessions.add(session))
Ok(())
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(

View File

@ -25,7 +25,7 @@ use thiserror::Error;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use super::olm::{Account, InboundGroupSession, Session}; use super::olm::{Account, InboundGroupSession, Session};
use olm_rs::errors::{OlmAccountError, OlmSessionError}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
use olm_rs::PicklingMode; use olm_rs::PicklingMode;
pub mod memorystore; pub mod memorystore;
@ -40,9 +40,11 @@ pub enum CryptoStoreError {
#[error("can't read or write from the store")] #[error("can't read or write from the store")]
Io(#[from] IoError), Io(#[from] IoError),
#[error("can't finish Olm Account operation {0}")] #[error("can't finish Olm Account operation {0}")]
OlmAccountError(#[from] OlmAccountError), OlmAccount(#[from] OlmAccountError),
#[error("can't finish Olm Session operation {0}")] #[error("can't finish Olm Session operation {0}")]
OlmSessionError(#[from] OlmSessionError), OlmSession(#[from] OlmSessionError),
#[error("can't finish Olm GruoupSession operation {0}")]
OlmGroupSession(#[from] OlmGroupSessionError),
#[error("URL can't be parsed")] #[error("URL can't be parsed")]
UrlParse(#[from] ParseError), UrlParse(#[from] ParseError),
#[error("error serializing data for the database")] #[error("error serializing data for the database")]
@ -70,7 +72,7 @@ pub trait CryptoStore: Debug + Send + Sync {
&mut self, &mut self,
sender_key: &str, sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>; ) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()>; async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
async fn get_inbound_group_session( async fn get_inbound_group_session(
&mut self, &mut self,
room_id: &str, room_id: &str,

View File

@ -201,32 +201,35 @@ impl SqliteStore {
.collect::<Result<Vec<Arc<Mutex<Session>>>>>()?) .collect::<Result<Vec<Arc<Mutex<Session>>>>>()?)
} }
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> { async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let pickle = session.pickle(self.get_pickle_mode());
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
query( let rows: Vec<(String, String, String, String)> = query_as(
"INSERT INTO inbound_group_sessions ( "SELECT pickle, sender_key, signing_key, room_id
session_id, account_id, sender_key, signing_key, FROM inbound_group_sessions WHERE account_id = ?",
room_id, pickle
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
ON CONFLICT(session_id) DO UPDATE SET
pickle = ?6
WHERE session_id = ?1
",
) )
.bind(&session.session_id())
.bind(account_id) .bind(account_id)
.bind(&session.sender_key) .fetch_all(&mut *connection)
.bind(&session.signing_key)
.bind(&session.room_id)
.bind(&pickle)
.execute(&mut *connection)
.await?; .await?;
self.inbound_group_sessions.add(session); Ok(rows
Ok(()) .iter()
.map(|row| {
let pickle = &row.0;
let sender_key = &row.1;
let signing_key = &row.2;
let room_id = &row.3;
Ok(InboundGroupSession::from_pickle(
pickle.to_string(),
self.get_pickle_mode(),
sender_key.to_string(),
signing_key.to_owned(),
room_id.to_owned(),
)?)
})
.collect::<Result<Vec<InboundGroupSession>>>()?)
} }
fn get_pickle_mode(&self) -> PicklingMode { fn get_pickle_mode(&self) -> PicklingMode {
@ -265,6 +268,17 @@ impl CryptoStore for SqliteStore {
None => None, None => None,
}; };
drop(connection);
let mut sessions = self.load_inbound_group_sessions().await?;
let _ = sessions
.drain(..)
.map(|s| {
self.inbound_group_sessions.add(s);
})
.collect::<()>();
Ok(result) Ok(result)
} }
@ -333,7 +347,9 @@ impl CryptoStore for SqliteStore {
} }
async fn add_and_save_session(&mut self, session: Session) -> Result<()> { async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
todo!() let session = self.sessions.add(session).await;
self.save_session(session).await?;
Ok(())
} }
async fn get_sessions( async fn get_sessions(
@ -343,8 +359,32 @@ impl CryptoStore for SqliteStore {
Ok(self.get_sessions_for(sender_key).await?) Ok(self.get_sessions_for(sender_key).await?)
} }
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> { async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
todo!() let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let pickle = session.pickle(self.get_pickle_mode());
let mut connection = self.connection.lock().await;
let session_id = session.session_id();
query(
"INSERT INTO inbound_group_sessions (
session_id, account_id, sender_key, signing_key,
room_id, pickle
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
ON CONFLICT(session_id) DO UPDATE SET
pickle = ?6
WHERE session_id = ?1
",
)
.bind(session_id)
.bind(account_id)
.bind(&session.sender_key)
.bind(&session.signing_key)
.bind(&session.room_id)
.bind(&pickle)
.execute(&mut *connection)
.await?;
Ok(self.inbound_group_sessions.add(session))
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(
@ -353,7 +393,9 @@ impl CryptoStore for SqliteStore {
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> { ) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
todo!() Ok(self
.inbound_group_sessions
.get(room_id, sender_key, session_id))
} }
} }
@ -387,12 +429,23 @@ mod test {
.expect("Can't create store") .expect("Can't create store")
} }
async fn get_loaded_store() -> (Arc<Mutex<Account>>, SqliteStore) {
let mut store = get_store().await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
(account, store)
}
fn get_account() -> Arc<Mutex<Account>> { fn get_account() -> Arc<Mutex<Account>> {
let account = Account::new(); let account = Account::new();
Arc::new(Mutex::new(account)) Arc::new(Mutex::new(account))
} }
fn get_account_and_session() -> (Arc<Mutex<Account>>, Arc<Mutex<Session>>) { fn get_account_and_session() -> (Arc<Mutex<Account>>, Session) {
let alice = Account::new(); let alice = Account::new();
let bob = Account::new(); let bob = Account::new();
@ -411,7 +464,7 @@ mod test {
.create_outbound_session(&sender_key, &one_time_key) .create_outbound_session(&sender_key, &one_time_key)
.unwrap(); .unwrap();
(Arc::new(Mutex::new(alice)), Arc::new(Mutex::new(session))) (Arc::new(Mutex::new(alice)), session)
} }
#[tokio::test] #[tokio::test]
@ -479,6 +532,7 @@ mod test {
async fn save_session() { async fn save_session() {
let mut store = get_store().await; let mut store = get_store().await;
let (account, session) = get_account_and_session(); let (account, session) = get_account_and_session();
let session = Arc::new(Mutex::new(session));
assert!(store.save_session(session.clone()).await.is_err()); assert!(store.save_session(session.clone()).await.is_err());
@ -494,6 +548,7 @@ mod test {
async fn load_sessions() { async fn load_sessions() {
let mut store = get_store().await; let mut store = get_store().await;
let (account, session) = get_account_and_session(); let (account, session) = get_account_and_session();
let session = Arc::new(Mutex::new(session));
store store
.save_account(account.clone()) .save_account(account.clone())
.await .await
@ -512,14 +567,28 @@ mod test {
} }
#[tokio::test] #[tokio::test]
async fn save_inbound_group_session() { async fn add_and_save_session() {
let mut store = get_store().await; let mut store = get_store().await;
let account = get_account(); let (account, session) = get_account_and_session();
let sender_key = session.sender_key.to_owned();
let session_id = session.session_id();
store store
.save_account(account.clone()) .save_account(account.clone())
.await .await
.expect("Can't save account"); .expect("Can't save account");
store.add_and_save_session(session).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
let sessions_lock = sessions.lock().await;
let session = &sessions_lock[0];
assert_eq!(session_id, *session.lock().await.session_id());
}
#[tokio::test]
async fn save_inbound_group_session() {
let (account, mut store) = get_loaded_store().await;
let acc = account.lock().await; let acc = account.lock().await;
let identity_keys = acc.identity_keys(); let identity_keys = acc.identity_keys();
@ -537,4 +606,31 @@ mod test {
.await .await
.expect("Can't save group session"); .expect("Can't save group session");
} }
#[tokio::test]
async fn load_inbound_group_session() {
let (account, mut store) = get_loaded_store().await;
let acc = account.lock().await;
let identity_keys = acc.identity_keys();
let outbound_session = OlmOutboundGroupSession::new();
let session = InboundGroupSession::new(
identity_keys.curve25519(),
identity_keys.ed25519(),
"!test:localhost",
&outbound_session.session_key(),
)
.expect("Can't create session");
let session_id = session.session_id();
store
.save_inbound_group_session(session)
.await
.expect("Can't save group session");
let sessions = store.load_inbound_group_sessions().await.unwrap();
assert_eq!(session_id, sessions[0].session_id());
}
} }