diff --git a/src/crypto/device.rs b/src/crypto/device.rs index 44bb8e3b..f6c34a26 100644 --- a/src/crypto/device.rs +++ b/src/crypto/device.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::collections::HashMap; -use std::sync::atomic::AtomicBool; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use atomic::Atomic; @@ -26,6 +26,8 @@ use crate::identifiers::{DeviceId, UserId}; pub struct Device { user_id: Arc, device_id: Arc, + // TODO the algorithm and the keys might change, so we can't make them read + // only here. Perhaps dashmap and a rwlock on the algorithms. algorithms: Arc>, keys: Arc>, display_name: Arc>, @@ -33,26 +35,86 @@ pub struct Device { trust_state: Arc>, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] +/// The trust state of a device. pub enum TrustState { - Verified, - BlackListed, - Ignored, - Unset, + /// The device has been verified and is trusted. + Verified = 0, + /// The device been blacklisted from communicating. + BlackListed = 1, + /// The trust state of the device is being ignored. + Ignored = 2, + /// The trust state is unset. + Unset = 3, +} + +impl From for TrustState { + fn from(state: i64) -> Self { + match state { + 0 => TrustState::Verified, + 1 => TrustState::BlackListed, + 2 => TrustState::Ignored, + 3 => TrustState::Unset, + _ => TrustState::Unset, + } + } } impl Device { - pub fn device_id(&self) -> &DeviceId { - &self.device_id + /// Create a new Device. + pub fn new( + user_id: UserId, + device_id: DeviceId, + display_name: Option, + trust_state: TrustState, + algorithms: Vec, + keys: HashMap, + ) -> Self { + Device { + user_id: Arc::new(user_id), + device_id: Arc::new(device_id), + display_name: Arc::new(display_name), + trust_state: Arc::new(Atomic::new(trust_state)), + algorithms: Arc::new(algorithms), + keys: Arc::new(keys), + deleted: Arc::new(AtomicBool::new(false)), + } } + /// The user id of the device owner. pub fn user_id(&self) -> &UserId { &self.user_id } - pub fn keys(&self, algorithm: &KeyAlgorithm) -> Option<&String> { + /// The unique ID of the device. + pub fn device_id(&self) -> &DeviceId { + &self.device_id + } + + /// Get the human readable name of the device. + pub fn display_name(&self) -> &Option { + &self.display_name + } + + /// Get the key of the given key algorithm belonging to this device. + pub fn get_key(&self, algorithm: &KeyAlgorithm) -> Option<&String> { self.keys.get(algorithm) } + + /// Get a map containing all the device keys. + pub fn keys(&self) -> &HashMap { + &self.keys + } + + /// Get the trust state of the device. + pub fn trust_state(&self) -> TrustState { + self.trust_state.load(Ordering::Relaxed) + } + + /// Get the list of algorithms this device supports. + pub fn algorithms(&self) -> &[Algorithm] { + &self.algorithms + } } impl From<&DeviceKeys> for Device { @@ -93,7 +155,7 @@ pub(crate) mod test { use std::convert::{From, TryFrom}; use crate::api::r0::keys::{DeviceKeys, KeyAlgorithm}; - use crate::crypto::device::Device; + use crate::crypto::device::{Device, TrustState}; use crate::identifiers::UserId; pub(crate) fn get_device() -> Device { @@ -136,12 +198,17 @@ pub(crate) mod test { assert_eq!(&user_id, device.user_id()); assert_eq!(device_id, device.device_id()); assert_eq!(device.algorithms.len(), 2); + assert_eq!(TrustState::Unset, device.trust_state()); assert_eq!( - device.keys(&KeyAlgorithm::Curve25519).unwrap(), + "Alice's mobile phone", + device.display_name().as_ref().unwrap() + ); + assert_eq!( + device.get_key(&KeyAlgorithm::Curve25519).unwrap(), "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4" ); assert_eq!( - device.keys(&KeyAlgorithm::Ed25519).unwrap(), + device.get_key(&KeyAlgorithm::Ed25519).unwrap(), "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM" ); } diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 20eeac5d..73f230f6 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -201,7 +201,7 @@ impl OlmMachine { let user_devices = self.store.get_user_devices(user_id).await.unwrap(); for device in user_devices.devices() { - let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) { + let sender_key = if let Some(k) = device.get_key(&KeyAlgorithm::Curve25519) { k } else { continue; @@ -276,7 +276,7 @@ impl OlmMachine { continue; }; - let signing_key = if let Some(k) = device.keys(&KeyAlgorithm::Ed25519) { + let signing_key = if let Some(k) = device.get_key(&KeyAlgorithm::Ed25519) { k } else { warn!( @@ -298,7 +298,7 @@ impl OlmMachine { continue; } - let curve_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) { + let curve_key = if let Some(k) = device.get_key(&KeyAlgorithm::Curve25519) { k } else { warn!( @@ -865,10 +865,10 @@ impl OlmMachine { let identity_keys = self.account.identity_keys(); let recipient_signing_key = recipient_device - .keys(&KeyAlgorithm::Ed25519) + .get_key(&KeyAlgorithm::Ed25519) .ok_or(OlmError::MissingSigningKey)?; let recipient_sender_key = recipient_device - .keys(&KeyAlgorithm::Curve25519) + .get_key(&KeyAlgorithm::Curve25519) .ok_or(OlmError::MissingSigningKey)?; let payload = json!({ @@ -957,7 +957,7 @@ impl OlmMachine { for user_id in users { for device in self.store.get_user_devices(user_id).await?.devices() { - let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) { + let sender_key = if let Some(k) = device.get_key(&KeyAlgorithm::Curve25519) { k } else { warn!( diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 6419361d..574211a2 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; +use std::mem; use std::path::{Path, PathBuf}; use std::result::Result as StdResult; use std::sync::Arc; @@ -28,8 +29,10 @@ use tokio::sync::Mutex; use zeroize::Zeroizing; use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; -use crate::crypto::device::Device; -use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices}; +use crate::api::r0::keys::KeyAlgorithm; +use crate::crypto::device::{Device, TrustState}; +use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; +use crate::events::Algorithm; use crate::identifiers::{DeviceId, RoomId, UserId}; pub struct SqliteStore { @@ -37,11 +40,14 @@ pub struct SqliteStore { device_id: Arc, account_id: Option, path: PathBuf, + sessions: SessionStore, inbound_group_sessions: GroupSessionStore, + devices: DeviceStore, + tracked_users: HashSet, + connection: Arc>, pickle_passphrase: Option>, - tracked_users: HashSet, } static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; @@ -85,6 +91,7 @@ impl SqliteStore { account_id: None, sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), + devices: DeviceStore::new(), path: path.as_ref().to_owned(), connection: Arc::new(Mutex::new(connection)), pickle_passphrase: passphrase, @@ -149,6 +156,61 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS devices ( + "id" INTEGER NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "user_id" TEXT NOT NULL, + "device_id" TEXT NOT NULL, + "display_name" TEXT, + "trust_state" INTEGER NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + UNIQUE(account_id,user_id,device_id) + ); + + CREATE INDEX IF NOT EXISTS "devices_account_id" ON "devices" ("account_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS algorithms ( + "id" INTEGER NOT NULL PRIMARY KEY, + "device_id" INTEGER NOT NULL, + "algorithm" TEXT NOT NULL, + FOREIGN KEY ("device_id") REFERENCES "devices" ("id") + ON DELETE CASCADE + UNIQUE(device_id, algorithm) + ); + + CREATE INDEX IF NOT EXISTS "algorithms_device_id" ON "algorithms" ("device_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS device_keys ( + "id" INTEGER NOT NULL PRIMARY KEY, + "device_id" INTEGER NOT NULL, + "algorithm" TEXT NOT NULL, + "key" TEXT NOT NULL, + FOREIGN KEY ("device_id") REFERENCES "devices" ("id") + ON DELETE CASCADE + UNIQUE(device_id, algorithm) + ); + + CREATE INDEX IF NOT EXISTS "device_keys_device_id" ON "device_keys" ("device_id"); + "#, + ) + .await?; + Ok(()) } @@ -243,6 +305,142 @@ impl SqliteStore { .collect::>>()?) } + async fn load_devices(&self) -> Result { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let rows: Vec<(i64, String, String, Option, i64)> = query_as( + "SELECT id, user_id, device_id, display_name, trust_state + FROM devices WHERE account_id = ?", + ) + .bind(account_id) + .fetch_all(&mut *connection) + .await?; + + let store = DeviceStore::new(); + + for row in rows { + let device_row_id = row.0; + let user_id = if let Ok(u) = UserId::try_from(&row.1 as &str) { + u + } else { + continue; + }; + + let device_id = &row.2.to_string(); + let display_name = &row.3; + let trust_state = TrustState::from(row.4); + + let algorithm_rows: Vec<(String,)> = + query_as("SELECT algorithm FROM algorithms WHERE device_id = ?") + .bind(device_row_id) + .fetch_all(&mut *connection) + .await?; + + let algorithms = algorithm_rows + .iter() + .map(|row| Algorithm::from(&row.0 as &str)) + .collect::>(); + + let key_rows: Vec<(String, String)> = + query_as("SELECT algorithm, key FROM device_keys WHERE device_id = ?") + .bind(device_row_id) + .fetch_all(&mut *connection) + .await?; + + let mut keys = HashMap::new(); + + for row in key_rows { + let algorithm = if let Ok(a) = KeyAlgorithm::try_from(&row.0 as &str) { + a + } else { + continue; + }; + + let key = &row.1; + + keys.insert(algorithm, key.to_owned()); + } + + let device = Device::new( + user_id, + device_id.to_owned(), + display_name.clone(), + trust_state, + algorithms, + keys, + ); + + store.add(device); + } + + Ok(store) + } + + async fn save_device_helper(&self, device: Device) -> Result<()> { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + + let mut connection = self.connection.lock().await; + + query( + "INSERT INTO devices ( + account_id, user_id, device_id, + display_name, trust_state + ) VALUES (?1, ?2, ?3, ?4, ?5) + ON CONFLICT(account_id, user_id, device_id) DO UPDATE SET + display_name = excluded.display_name, + trust_state = excluded.trust_state + ", + ) + .bind(account_id) + .bind(&device.user_id().to_string()) + .bind(device.device_id()) + .bind(device.display_name()) + .bind(device.trust_state() as i64) + .execute(&mut *connection) + .await?; + + let row: (i64,) = query_as( + "SELECT id FROM devices + WHERE user_id = ? and device_id = ?", + ) + .bind(&device.user_id().to_string()) + .bind(device.device_id()) + .fetch_one(&mut *connection) + .await?; + + let device_row_id = row.0; + + for algorithm in device.algorithms() { + query( + "INSERT OR IGNORE INTO algorithms ( + device_id, algorithm + ) VALUES (?1, ?2) + ", + ) + .bind(device_row_id) + .bind(algorithm.to_string()) + .execute(&mut *connection) + .await?; + } + + for (key_algorithm, key) in device.keys() { + query( + "INSERT OR IGNORE INTO device_keys ( + device_id, algorithm, key + ) VALUES (?1, ?2, ?3) + ", + ) + .bind(device_row_id) + .bind(key_algorithm.to_string()) + .bind(key) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + fn get_pickle_mode(&self) -> PicklingMode { match &self.pickle_passphrase { Some(p) => PicklingMode::Encrypted { @@ -289,6 +487,9 @@ impl CryptoStore for SqliteStore { }) .collect::<()>(); + let devices = self.load_devices().await?; + mem::replace(&mut self.devices, devices); + // TODO load the tracked users here as well. Ok(result) @@ -303,10 +504,8 @@ impl CryptoStore for SqliteStore { user_id, device_id, pickle, shared ) VALUES (?1, ?2, ?3, ?4) ON CONFLICT(user_id, device_id) DO UPDATE SET - pickle = ?3, - shared = ?4 - WHERE user_id = ?1 and - device_id = ?2 + pickle = excluded.pickle, + shared = excluded.shared ", ) .bind(&*self.user_id.to_string()) @@ -374,8 +573,7 @@ impl CryptoStore for SqliteStore { room_id, pickle ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) ON CONFLICT(session_id) DO UPDATE SET - pickle = ?6 - WHERE session_id = ?1 + pickle = excluded.pickle ", ) .bind(session_id) @@ -410,16 +608,16 @@ impl CryptoStore for SqliteStore { Ok(self.tracked_users.insert(user.clone())) } - async fn get_device(&self, _user_id: &UserId, _device_id: &DeviceId) -> Result> { - todo!() + async fn save_device(&self, device: Device) -> Result<()> { + self.save_device_helper(device).await } - async fn get_user_devices(&self, _user_id: &UserId) -> Result { - todo!() + async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { + Ok(self.devices.get(user_id, device_id)) } - async fn save_device(&self, _device: Device) -> Result<()> { - todo!() + async fn get_user_devices(&self, user_id: &UserId) -> Result { + Ok(self.devices.user_devices(user_id)) } } @@ -461,7 +659,7 @@ mod test { &user_id, DEVICE_ID, tmpdir_path, - "secret".to_string(), + passphrase.to_owned(), ) .await .expect("Can't create a passphrase protected store") @@ -718,7 +916,7 @@ mod test { #[tokio::test] async fn test_tracked_users() { - let (account, mut store, _dir) = get_loaded_store().await; + let (_account, mut store, _dir) = get_loaded_store().await; let device = get_device(); assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); @@ -728,4 +926,35 @@ mod test { tracked_users.contains(device.user_id()); } + + #[tokio::test] + async fn device_saving() { + let (_account, store, dir) = get_loaded_store().await; + let device = get_device(); + + store.save_device(device.clone()).await.unwrap(); + + drop(store); + + let mut store = + SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path()) + .await + .expect("Can't create store"); + + store.load_account().await.unwrap(); + + let loaded_device = store + .get_device(device.user_id(), device.device_id()) + .await + .unwrap() + .unwrap(); + + assert_eq!(device, loaded_device); + + for algorithm in loaded_device.algorithms() { + assert!(device.algorithms().contains(algorithm)); + } + assert_eq!(device.algorithms().len(), loaded_device.algorithms().len()); + assert_eq!(device.keys(), loaded_device.keys()); + } }