diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 00dfccd5..e55bcb8b 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -261,7 +261,24 @@ impl SqliteStore { UNIQUE(session_id, algorithm) ); - CREATE INDEX IF NOT EXISTS "group_session_claimed_keys_session_id" ON "inbound_group_sessions" ("session_id"); + CREATE INDEX IF NOT EXISTS "group_session_claimed_keys_session_id" ON "group_session_claimed_keys" ("session_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS group_session_chains ( + "id" INTEGER NOT NULL PRIMARY KEY, + "key" TEXT NOT NULL, + "session_id" INTEGER NOT NULL, + FOREIGN KEY ("session_id") REFERENCES "inbound_group_sessions" ("id") + ON DELETE CASCADE + UNIQUE(session_id, key) + ); + + CREATE INDEX IF NOT EXISTS "group_session_chains_session_id" ON "group_session_chains" ("session_id"); "#, ) .await?; @@ -527,14 +544,26 @@ impl SqliteStore { }) .collect(); + let mut chain_rows: Vec<(String,)> = + query_as("SELECT key, key FROM group_session_chains WHERE session_id = ?") + .bind(session_row_id) + .fetch_all(&mut *connection) + .await?; + + let chains: Vec = chain_rows.drain(..).map(|r| r.0).collect(); + + let chains = if chains.is_empty() { + None + } else { + Some(chains) + }; + let pickle = PickledInboundGroupSession { pickle: InboundGroupSessionPickle::from(pickle), sender_key, signing_key: claimed_keys, room_id: RoomId::try_from(room_id)?, - // Fixme we need to store/restore these once we get support - // for key requesting/forwarding. - forwarding_chains: None, + forwarding_chains: chains, imported, }; @@ -805,10 +834,6 @@ impl SqliteStore { let pickle = session.pickle(self.get_pickle_mode()).await; let session_id = session.session_id(); - // FIXME we need to store/restore the forwarding chains. - // FIXME this should be converted so it accepts an array of sessions for - // the key import feature. - query( "REPLACE INTO inbound_group_sessions ( session_id, account_id, sender_key, @@ -839,7 +864,7 @@ impl SqliteStore { for (key_id, key) in pickle.signing_key { query( - "INSERT OR IGNORE INTO group_session_claimed_keys ( + "REPLACE INTO group_session_claimed_keys ( session_id, algorithm, key ) VALUES (?1, ?2, ?3) ", @@ -851,6 +876,21 @@ impl SqliteStore { .await?; } + if let Some(chains) = pickle.forwarding_chains { + for key in chains { + query( + "REPLACE INTO group_session_chains ( + session_id, key + ) VALUES (?1, ?2) + ", + ) + .bind(session_row_id) + .bind(key) + .execute(&mut *connection) + .await?; + } + } + Ok(()) } @@ -1606,7 +1646,7 @@ mod test { #[tokio::test] async fn load_inbound_group_session() { - let (account, store, _dir) = get_loaded_store().await; + let (account, store, dir) = get_loaded_store().await; let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); @@ -1618,11 +1658,22 @@ mod test { ) .expect("Can't create session"); + let mut export = session.export().await; + + export.forwarding_curve25519_key_chain = vec!["some_chain".to_owned()]; + + let session = InboundGroupSession::from_export(export).unwrap(); + store .save_inbound_group_sessions(&[session.clone()]) .await .expect("Can't save group session"); + let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) + .await + .expect("Can't create store"); + + store.load_account().await.unwrap(); store.load_inbound_group_sessions().await.unwrap(); let loaded_session = store @@ -1631,6 +1682,8 @@ mod test { .unwrap() .unwrap(); assert_eq!(session, loaded_session); + let export = loaded_session.export().await; + assert!(!export.forwarding_curve25519_key_chain.is_empty()) } #[tokio::test]