diff --git a/src/async_client.rs b/src/async_client.rs index cdd12870..fae92458 100644 --- a/src/async_client.rs +++ b/src/async_client.rs @@ -722,7 +722,6 @@ impl AsyncClient { callback(response).await; - // TODO query keys here. // TODO send out to-device messages here #[cfg(feature = "encryption")] diff --git a/src/crypto/device.rs b/src/crypto/device.rs index f6c34a26..d0b2fcf1 100644 --- a/src/crypto/device.rs +++ b/src/crypto/device.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashMap; +use std::mem; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -26,8 +27,6 @@ use crate::identifiers::{DeviceId, UserId}; pub struct Device { user_id: Arc, device_id: Arc, - // TODO the algorithm and the keys might change, so we can't make them read - // only here. Perhaps dashmap and a rwlock on the algorithms. algorithms: Arc>, keys: Arc>, display_name: Arc>, @@ -115,6 +114,40 @@ impl Device { pub fn algorithms(&self) -> &[Algorithm] { &self.algorithms } + + /// Is the device deleted. + pub fn deleted(&self) -> bool { + self.deleted.load(Ordering::Relaxed) + } + + /// Update a device with a new device keys struct. + pub(crate) fn update_device(&mut self, device_keys: &DeviceKeys) { + let mut keys = HashMap::new(); + + for (key_id, key) in device_keys.keys.iter() { + let key_id = key_id.0; + keys.insert(key_id, key.clone()); + } + + let display_name = Arc::new( + device_keys + .unsigned + .as_ref() + .map(|d| d.device_display_name.clone()), + ); + + mem::replace( + &mut self.algorithms, + Arc::new(device_keys.algorithms.clone()), + ); + mem::replace(&mut self.keys, Arc::new(keys)); + mem::replace(&mut self.display_name, display_name); + } + + /// Mark the device as deleted. + pub(crate) fn mark_as_deleted(&self) { + self.deleted.store(true, Ordering::Relaxed); + } } impl From<&DeviceKeys> for Device { @@ -158,7 +191,7 @@ pub(crate) mod test { use crate::crypto::device::{Device, TrustState}; use crate::identifiers::UserId; - pub(crate) fn get_device() -> Device { + fn device_keys() -> DeviceKeys { let user_id = UserId::try_from("@alice:example.org").unwrap(); let device_id = "DEVICEID"; @@ -183,8 +216,11 @@ pub(crate) mod test { } }); - let device_keys: DeviceKeys = serde_json::from_value(device_keys).unwrap(); + serde_json::from_value(device_keys).unwrap() + } + pub(crate) fn get_device() -> Device { + let device_keys = device_keys(); Device::from(&device_keys) } @@ -212,4 +248,36 @@ pub(crate) mod test { "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM" ); } + + #[test] + fn update_a_device() { + let mut device = get_device(); + + assert_eq!( + "Alice's mobile phone", + device.display_name().as_ref().unwrap() + ); + + let mut device_keys = device_keys(); + device_keys.unsigned.as_mut().unwrap().device_display_name = + "Alice's work computer".to_owned(); + device.update_device(&device_keys); + + assert_eq!( + "Alice's work computer", + device.display_name().as_ref().unwrap() + ); + } + + #[test] + fn delete_a_device() { + let device = get_device(); + assert!(!device.deleted()); + + let device_clone = device.clone(); + + device.mark_as_deleted(); + assert!(device.deleted()); + assert!(device_clone.deleted()); + } } diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 73f230f6..52c6a39c 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -342,15 +342,17 @@ impl OlmMachine { /// Receive a successful keys query response. /// + /// Returns a list of devices newly discovered devices and devices that + /// changed. + /// /// # Arguments /// /// * `response` - The keys query response of the request that the client /// performed. - // TODO this should return a list of changed devices. pub async fn receive_keys_query_response( &mut self, response: &keys::get_keys::Response, - ) -> Result<()> { + ) -> Result> { let mut changed_devices = Vec::new(); for (user_id, device_map) in &response.device_keys { @@ -370,20 +372,15 @@ impl OlmMachine { continue; } - // let curve_key_id = - // AlgorithmAndDeviceId(KeyAlgorithm::Curve25519, device_id.to_owned()); let ed_key_id = AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, device_id.to_owned()); - // TODO check if the curve key changed for an existing device. - // let sender_key = if let Some(k) = device_keys.keys.get(&curve_key_id) { - // k - // } else { - // continue; - // }; - let signing_key = if let Some(k) = device_keys.keys.get(&ed_key_id) { k } else { + warn!( + "Ed25519 identity key wasn't found for user/device {} {}", + user_id, device_id + ); continue; }; @@ -398,19 +395,28 @@ impl OlmMachine { continue; } - let device = self - .store - .get_device(&user_id, device_id) - .await - .expect("Can't load device"); + let device = self.store.get_device(&user_id, device_id).await?; - if let Some(_d) = device { - // TODO check what and if anything changed for the device. + let device = if let Some(mut d) = device { + let stored_signing_key = d.get_key(&KeyAlgorithm::Ed25519); + + if let Some(stored_signing_key) = stored_signing_key { + if stored_signing_key != signing_key { + warn!("Ed25519 key has changed for {} {}", user_id, device_id); + continue; + } + } + + d.update_device(device_keys); + + d } else { let device = Device::from(device_keys); - info!("Found new device {:?}", device); - changed_devices.push(device); - } + info!("Adding a new device to the device store {:?}", device); + device + }; + + changed_devices.push(device); } let current_devices: HashSet<&DeviceId> = device_map.keys().collect(); @@ -419,16 +425,20 @@ impl OlmMachine { let deleted_devices = stored_devices_set.difference(¤t_devices); - for _device_id in deleted_devices { - // TODO delete devices here. + for device_id in deleted_devices { + if let Some(device) = stored_devices.get(device_id) { + device.mark_as_deleted(); + // TODO change this to a delete device. + self.store.save_device(device).await?; + } } } - for device in changed_devices { - self.store.save_device(device).await.unwrap(); + for device in &changed_devices { + self.store.save_device(device.clone()).await?; } - Ok(()) + Ok(changed_devices) } /// Generate new one-time keys. @@ -1238,6 +1248,27 @@ mod test { keys::upload_keys::Response::try_from(data).expect("Can't parse the keys upload response") } + fn keys_query_response() -> keys::get_keys::Response { + let data = response_from_file("tests/data/keys_query.json"); + keys::get_keys::Response::try_from(data).expect("Can't parse the keys upload response") + } + + async fn get_prepared_machine() -> OlmMachine { + let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); + machine.uploaded_signed_key_count = Some(0); + let (_, _) = machine + .keys_for_upload() + .await + .expect("Can't prepare initial key upload"); + let response = keys_upload_response(); + machine + .receive_keys_upload_response(&response) + .await + .unwrap(); + + machine + } + #[tokio::test] async fn create_olm_machine() { let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); @@ -1404,4 +1435,29 @@ mod test { let ret = machine.keys_for_upload().await; assert!(ret.is_err()); } + + #[tokio::test] + async fn test_keys_query() { + let mut machine = get_prepared_machine().await; + let response = keys_query_response(); + let alice_id = UserId::try_from("@alice:example.org").unwrap(); + let alice_device_id = "JLAFKJWSCS".to_owned(); + + let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap(); + assert!(alice_devices.devices().peekable().peek().is_none()); + + machine + .receive_keys_query_response(&response) + .await + .unwrap(); + + let device = machine + .store + .get_device(&alice_id, &alice_device_id) + .await + .unwrap() + .unwrap(); + assert_eq!(device.user_id(), &alice_id); + assert_eq!(device.device_id(), &alice_device_id); + } } diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index 2510edb9..ba1cd6f4 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -185,6 +185,16 @@ impl DeviceStore { .and_then(|m| m.get(device_id).map(|d| d.value().clone())) } + /// Remove the device with the given device_id and belonging to the given user. + /// + /// Returns the device if it was removed, None if it wasn't in the store. + pub fn remove(&self, user_id: &UserId, device_id: &str) -> Option { + self.entries + .get(user_id) + .and_then(|m| m.remove(device_id)) + .and_then(|(_, d)| Some(d)) + } + /// Get a read-only view over all devices of the given user. pub fn user_devices(&self, user_id: &UserId) -> UserDevices { if !self.entries.contains_key(user_id) { @@ -286,5 +296,10 @@ mod test { let loaded_device = user_devices.get(device.device_id()).unwrap(); assert_eq!(device, loaded_device); + + store.remove(device.user_id(), device.device_id()); + + let loaded_device = store.get(device.user_id(), device.device_id()); + assert!(loaded_device.is_none()); } } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 237ea913..df7f336d 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -609,6 +609,7 @@ impl CryptoStore for SqliteStore { } async fn save_device(&self, device: Device) -> Result<()> { + self.devices.add(device.clone()); self.save_device_helper(device).await }