diff --git a/matrix_sdk_crypto/src/identities/user.rs b/matrix_sdk_crypto/src/identities/user.rs index d3a78979..5893b7ad 100644 --- a/matrix_sdk_crypto/src/identities/user.rs +++ b/matrix_sdk_crypto/src/identities/user.rs @@ -756,6 +756,10 @@ pub(crate) mod test { OwnUserIdentity::new(master_key.into(), self_signing.into(), user_signing.into()).unwrap() } + pub(crate) fn get_own_identity() -> OwnUserIdentity { + own_identity(&own_key_query()) + } + #[test] fn own_identity_create() { let user_id = user_id!("@example:localhost"); diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 50b0b5f4..f4619a59 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -21,8 +21,9 @@ use std::{ }; use async_trait::async_trait; -use dashmap::DashSet; +use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{ + api::r0::keys::{CrossSigningKey, KeyUsage}, identifiers::{ DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId, }, @@ -135,8 +136,8 @@ impl SqliteStore { passphrase: Option>, ) -> Result { let url = SqliteStore::path_to_url(path.as_ref())?; - let connection = SqliteConnection::connect(url.as_ref()).await?; + let store = SqliteStore { user_id: Arc::new(user_id.to_owned()), device_id: Arc::new(device_id.into()), @@ -151,6 +152,7 @@ impl SqliteStore { users_for_key_query: Arc::new(DashSet::new()), }; store.create_tables().await?; + Ok(store) } @@ -310,6 +312,61 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS users ( + "id" INTEGER NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "user_id" TEXT NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + UNIQUE(account_id,user_id) + ); + + CREATE INDEX IF NOT EXISTS "users_account_id" ON "users" ("account_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS user_keys ( + "id" INTEGER NOT NULL PRIMARY KEY, + "key" TEXT NOT NULL, + "key_id" TEXT NOT NULL, + "key_type" TEXT NOT NULL, + "usage" TEXT NOT NULL, + "user_id" INTEGER NOT NULL, + FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE + UNIQUE(user_id, key_id, key_type) + ); + + CREATE INDEX IF NOT EXISTS "user_keys_user_id" ON "users" ("user_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS user_key_signatures ( + "id" INTEGER NOT NULL PRIMARY KEY, + "user_id" TEXT NOT NULL, + "key_id" INTEGER NOT NULL, + "signature" TEXT NOT NULL, + "user_key" INTEGER NOT NULL, + FOREIGN KEY ("user_key") REFERENCES "user_keys" ("id") + ON DELETE CASCADE + UNIQUE(user_id, key_id, user_key) + ); + + CREATE INDEX IF NOT EXISTS "user_keys_device_id" ON "device_keys" ("device_id"); + "#, + ) + .await?; + Ok(()) } @@ -670,6 +727,159 @@ impl SqliteStore { None => PicklingMode::Unencrypted, } } + + async fn load_user(&self, user_id: &UserId) -> Result> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let row: Option<(i64,)> = + query_as("SELECT id FROM users WHERE account_id = ? and user_id = ?") + .bind(account_id) + .bind(user_id.as_str()) + .fetch_optional(&mut *connection) + .await?; + + let user_row_id = if let Some(row) = row { + row.0 + } else { + return Ok(None); + }; + + let key_rows: Vec<(i64, String, String, String)> = query_as( + "SELECT id, key_id, key, usage FROM user_keys WHERE user_id = ? and key_type = ?", + ) + .bind(user_row_id) + .bind("master_key") + .fetch_all(&mut *connection) + .await?; + + let mut keys = BTreeMap::new(); + let mut signatures = BTreeMap::new(); + let mut key_usage = HashSet::new(); + + for row in key_rows { + let key_row_id = row.0; + let key_id = row.1; + let key = row.2; + let usage: Vec = serde_json::from_str(&row.3)?; + + keys.insert(key_id, key); + key_usage.extend(usage); + + let mut signature_rows: Vec<(String, String, String)> = query_as( + "SELECT user_id, key_id, signature, FROM user_key_signatures WHERE user_key = ?", + ) + .bind(user_row_id) + .bind("master_key") + .fetch_all(&mut *connection) + .await?; + + for row in signature_rows.drain(..) { + let user_id = if let Ok(u) = UserId::try_from(row.0) { + u + } else { + continue; + }; + + let key_id = row.1; + let signature = row.2; + + signatures + .entry(user_id) + .or_insert_with(BTreeMap::new) + .insert(key_id, signature); + } + } + + let usage: Vec = key_usage + .iter() + .filter_map(|u| serde_json::from_str(u).ok()) + .collect(); + + let key = CrossSigningKey { + user_id: user_id.to_owned(), + usage, + keys, + signatures, + }; + + Ok(None) + } + + async fn save_user_helper(&self, user: &UserIdentities) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + let mut connection = self.connection.lock().await; + + query( + "INSERT OR IGNORE INTO users ( + account_id, user_id + ) VALUES (?1, ?2) + ", + ) + .bind(account_id) + .bind(user.user_id().as_str()) + .execute(&mut *connection) + .await?; + + let row: (i64,) = query_as( + "SELECT id FROM users + WHERE account_id = ? and user_id = ?", + ) + .bind(account_id) + .bind(user.user_id().as_str()) + .fetch_one(&mut *connection) + .await?; + + let user_row_id = row.0; + + for (key_id, key) in user.master_key() { + query( + "INSERT OR IGNORE INTO user_keys ( + user_id, key_type, key_id, key, usage + ) VALUES (?1, ?2, ?3, ?4, ?5) + ", + ) + .bind(user_row_id) + .bind("master_key") + .bind(key_id.as_str()) + .bind(key) + .bind(serde_json::to_string(user.master_key().usage())?) + .execute(&mut *connection) + .await?; + + let row: (i64,) = query_as( + "SELECT id FROM user_keys + WHERE user_id = ? and key_id = ? and key_type = ?", + ) + .bind(user_row_id) + .bind(key_id.as_str()) + .bind("master_key") + .fetch_one(&mut *connection) + .await?; + + let key_row_id = row.0; + + for (user_id, signature_map) in user.master_key().signatures() { + for (key_id, signature) in signature_map { + query( + "INSERT OR IGNORE INTO user_key_signatures ( + user_key, user_id, key_id, signature + ) VALUES (?1, ?2, ?3, ?4) + ", + ) + .bind(key_row_id) + .bind(user_id.as_str()) + .bind(key_id.as_str()) + .bind(signature) + .execute(&mut *connection) + .await?; + } + } + } + + Ok(()) + } } #[async_trait] @@ -899,11 +1109,15 @@ impl CryptoStore for SqliteStore { Ok(self.devices.user_devices(user_id)) } - async fn get_user_identity(&self, _user_id: &UserId) -> Result> { - Ok(None) + async fn get_user_identity(&self, user_id: &UserId) -> Result> { + self.load_user(user_id).await } - async fn save_user_identities(&self, _users: &[UserIdentities]) -> Result<()> { + async fn save_user_identities(&self, users: &[UserIdentities]) -> Result<()> { + for user in users { + self.save_user_helper(user).await?; + } + Ok(()) } } @@ -922,7 +1136,7 @@ impl std::fmt::Debug for SqliteStore { #[cfg(test)] mod test { use crate::{ - identities::device::test::get_device, + identities::{device::test::get_device, user::test::get_own_identity}, olm::{Account, GroupSessionKey, InboundGroupSession, Session}, }; use matrix_sdk_common::{ @@ -1311,4 +1525,41 @@ mod test { assert!(loaded_device.is_none()); } + + #[tokio::test] + async fn user_saving() { + let (_account, store, dir) = get_loaded_store().await; + let own_identity = get_own_identity(); + + store + .save_user_identities(&[own_identity.into()]) + .await + .expect("Can't save identity"); + + drop(store); + + // let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) + // .await + // .expect("Can't create store"); + + // store.load_account().await.unwrap(); + + // let loaded_device = store + // .get_device(device.user_id(), device.device_id()) + // .await + // .unwrap() + // .unwrap(); + + // assert_eq!(device, loaded_device); + + // for algorithm in loaded_device.algorithms() { + // assert!(device.algorithms().contains(algorithm)); + // } + // assert_eq!(device.algorithms().len(), loaded_device.algorithms().len()); + // assert_eq!(device.keys(), loaded_device.keys()); + + // let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); + // assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); + // assert_eq!(user_devices.devices().next().unwrap(), &device); + } }