diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index 5b7b5820..e86961a1 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -31,7 +31,7 @@ impl SessionStore { } } - pub async fn add(&mut self, session: Session) { + pub async fn add(&mut self, session: Session) -> Arc> { if !self.entries.contains_key(&session.sender_key) { self.entries.insert( session.sender_key.to_owned(), @@ -39,7 +39,10 @@ impl SessionStore { ); } let mut sessions = self.entries.get_mut(&session.sender_key).unwrap(); - sessions.lock().await.push(Arc::new(Mutex::new(session))); + let session = Arc::new(Mutex::new(session)); + sessions.lock().await.push(session.clone()); + + session } pub fn get(&self, sender_key: &str) -> Option>>>>> { diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 4a470db8..1165b899 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -228,6 +228,23 @@ impl InboundGroupSession { }) } + pub fn from_pickle( + pickle: String, + pickle_mode: PicklingMode, + sender_key: String, + signing_key: String, + room_id: String, + ) -> Result { + let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?; + Ok(InboundGroupSession { + inner: session, + sender_key, + signing_key, + room_id, + forwarding_chains: None, + }) + } + pub fn session_id(&self) -> String { self.inner.session_id() } diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index f3744f79..9c226096 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -61,9 +61,8 @@ impl CryptoStore for MemoryStore { Ok(self.sessions.get(sender_key)) } - async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> { - self.inbound_group_sessions.add(session); - Ok(()) + async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result { + Ok(self.inbound_group_sessions.add(session)) } async fn get_inbound_group_session( diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index ae2f1182..f2ee300a 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -25,7 +25,7 @@ use thiserror::Error; use tokio::sync::Mutex; use super::olm::{Account, InboundGroupSession, Session}; -use olm_rs::errors::{OlmAccountError, OlmSessionError}; +use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::PicklingMode; pub mod memorystore; @@ -40,9 +40,11 @@ pub enum CryptoStoreError { #[error("can't read or write from the store")] Io(#[from] IoError), #[error("can't finish Olm Account operation {0}")] - OlmAccountError(#[from] OlmAccountError), + OlmAccount(#[from] OlmAccountError), #[error("can't finish Olm Session operation {0}")] - OlmSessionError(#[from] OlmSessionError), + OlmSession(#[from] OlmSessionError), + #[error("can't finish Olm GruoupSession operation {0}")] + OlmGroupSession(#[from] OlmGroupSessionError), #[error("URL can't be parsed")] UrlParse(#[from] ParseError), #[error("error serializing data for the database")] @@ -70,7 +72,7 @@ pub trait CryptoStore: Debug + Send + Sync { &mut self, sender_key: &str, ) -> Result>>>>>>; - async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()>; + async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result; async fn get_inbound_group_session( &mut self, room_id: &str, diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 5cd57e8a..01bfdda0 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -201,32 +201,35 @@ impl SqliteStore { .collect::>>>>()?) } - async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> { + async fn load_inbound_group_sessions(&self) -> Result> { let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; - let pickle = session.pickle(self.get_pickle_mode()); let mut connection = self.connection.lock().await; - query( - "INSERT INTO inbound_group_sessions ( - session_id, account_id, sender_key, signing_key, - room_id, pickle - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) - ON CONFLICT(session_id) DO UPDATE SET - pickle = ?6 - WHERE session_id = ?1 - ", + let rows: Vec<(String, String, String, String)> = query_as( + "SELECT pickle, sender_key, signing_key, room_id + FROM inbound_group_sessions WHERE account_id = ?", ) - .bind(&session.session_id()) .bind(account_id) - .bind(&session.sender_key) - .bind(&session.signing_key) - .bind(&session.room_id) - .bind(&pickle) - .execute(&mut *connection) + .fetch_all(&mut *connection) .await?; - self.inbound_group_sessions.add(session); - Ok(()) + Ok(rows + .iter() + .map(|row| { + let pickle = &row.0; + let sender_key = &row.1; + let signing_key = &row.2; + let room_id = &row.3; + + Ok(InboundGroupSession::from_pickle( + pickle.to_string(), + self.get_pickle_mode(), + sender_key.to_string(), + signing_key.to_owned(), + room_id.to_owned(), + )?) + }) + .collect::>>()?) } fn get_pickle_mode(&self) -> PicklingMode { @@ -265,6 +268,17 @@ impl CryptoStore for SqliteStore { None => None, }; + drop(connection); + + let mut sessions = self.load_inbound_group_sessions().await?; + + let _ = sessions + .drain(..) + .map(|s| { + self.inbound_group_sessions.add(s); + }) + .collect::<()>(); + Ok(result) } @@ -333,7 +347,9 @@ impl CryptoStore for SqliteStore { } async fn add_and_save_session(&mut self, session: Session) -> Result<()> { - todo!() + let session = self.sessions.add(session).await; + self.save_session(session).await?; + Ok(()) } async fn get_sessions( @@ -343,8 +359,32 @@ impl CryptoStore for SqliteStore { Ok(self.get_sessions_for(sender_key).await?) } - async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> { - todo!() + async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + let pickle = session.pickle(self.get_pickle_mode()); + let mut connection = self.connection.lock().await; + let session_id = session.session_id(); + + query( + "INSERT INTO inbound_group_sessions ( + session_id, account_id, sender_key, signing_key, + room_id, pickle + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) + ON CONFLICT(session_id) DO UPDATE SET + pickle = ?6 + WHERE session_id = ?1 + ", + ) + .bind(session_id) + .bind(account_id) + .bind(&session.sender_key) + .bind(&session.signing_key) + .bind(&session.room_id) + .bind(&pickle) + .execute(&mut *connection) + .await?; + + Ok(self.inbound_group_sessions.add(session)) } async fn get_inbound_group_session( @@ -353,7 +393,9 @@ impl CryptoStore for SqliteStore { sender_key: &str, session_id: &str, ) -> Result>>> { - todo!() + Ok(self + .inbound_group_sessions + .get(room_id, sender_key, session_id)) } } @@ -387,12 +429,23 @@ mod test { .expect("Can't create store") } + async fn get_loaded_store() -> (Arc>, SqliteStore) { + let mut store = get_store().await; + let account = get_account(); + store + .save_account(account.clone()) + .await + .expect("Can't save account"); + + (account, store) + } + fn get_account() -> Arc> { let account = Account::new(); Arc::new(Mutex::new(account)) } - fn get_account_and_session() -> (Arc>, Arc>) { + fn get_account_and_session() -> (Arc>, Session) { let alice = Account::new(); let bob = Account::new(); @@ -411,7 +464,7 @@ mod test { .create_outbound_session(&sender_key, &one_time_key) .unwrap(); - (Arc::new(Mutex::new(alice)), Arc::new(Mutex::new(session))) + (Arc::new(Mutex::new(alice)), session) } #[tokio::test] @@ -479,6 +532,7 @@ mod test { async fn save_session() { let mut store = get_store().await; let (account, session) = get_account_and_session(); + let session = Arc::new(Mutex::new(session)); assert!(store.save_session(session.clone()).await.is_err()); @@ -494,6 +548,7 @@ mod test { async fn load_sessions() { let mut store = get_store().await; let (account, session) = get_account_and_session(); + let session = Arc::new(Mutex::new(session)); store .save_account(account.clone()) .await @@ -512,14 +567,28 @@ mod test { } #[tokio::test] - async fn save_inbound_group_session() { + async fn add_and_save_session() { let mut store = get_store().await; - let account = get_account(); + let (account, session) = get_account_and_session(); + let sender_key = session.sender_key.to_owned(); + let session_id = session.session_id(); store .save_account(account.clone()) .await .expect("Can't save account"); + store.add_and_save_session(session).await.unwrap(); + + let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); + let sessions_lock = sessions.lock().await; + let session = &sessions_lock[0]; + + assert_eq!(session_id, *session.lock().await.session_id()); + } + + #[tokio::test] + async fn save_inbound_group_session() { + let (account, mut store) = get_loaded_store().await; let acc = account.lock().await; let identity_keys = acc.identity_keys(); @@ -537,4 +606,31 @@ mod test { .await .expect("Can't save group session"); } + + #[tokio::test] + async fn load_inbound_group_session() { + let (account, mut store) = get_loaded_store().await; + + let acc = account.lock().await; + let identity_keys = acc.identity_keys(); + let outbound_session = OlmOutboundGroupSession::new(); + let session = InboundGroupSession::new( + identity_keys.curve25519(), + identity_keys.ed25519(), + "!test:localhost", + &outbound_session.session_key(), + ) + .expect("Can't create session"); + + let session_id = session.session_id(); + + store + .save_inbound_group_session(session) + .await + .expect("Can't save group session"); + + let sessions = store.load_inbound_group_sessions().await.unwrap(); + + assert_eq!(session_id, sessions[0].session_id()); + } }