diff --git a/src/crypto/device.rs b/src/crypto/device.rs index c5d50a32..cfde349c 100644 --- a/src/crypto/device.rs +++ b/src/crypto/device.rs @@ -18,9 +18,9 @@ use std::sync::Arc; use atomic::Atomic; -use ruma_client_api::r0::keys::{DeviceKeys, KeyAlgorithm}; -use ruma_events::Algorithm; -use ruma_identifiers::{DeviceId, UserId}; +use crate::api::r0::keys::{DeviceKeys, KeyAlgorithm}; +use crate::events::Algorithm; +use crate::identifiers::{DeviceId, UserId}; #[derive(Debug, Clone)] pub struct Device { @@ -82,3 +82,61 @@ impl From<&DeviceKeys> for Device { } } } + +impl PartialEq for Device { + fn eq(&self, other: &Self) -> bool { + self.user_id() == other.user_id() && self.device_id() == other.device_id() + } +} + +#[cfg(test)] +pub(crate) mod test { + use serde_json::json; + use std::convert::{From, TryFrom}; + + use crate::api::r0::keys::DeviceKeys; + use crate::crypto::device::Device; + use crate::identifiers::UserId; + + pub(crate) fn get_device() -> Device { + let user_id = UserId::try_from("@alice:example.org").unwrap(); + let device_id = "DEVICEID"; + + let device_keys = json!({ + "algorithms": vec![ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" + ], + "device_id": device_id, + "user_id": user_id.to_string(), + "keys": { + "curve25519:DEVICEID": "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4", + "ed25519:DEVICEID": "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM" + }, + "signatures": { + user_id.to_string(): { + "ed25519:DEVICEID": "m53Wkbh2HXkc3vFApZvCrfXcX3AI51GsDHustMhKwlv3TuOJMj4wistcOTM8q2+e/Ro7rWFUb9ZfnNbwptSUBA" + } + }, + "unsigned": { + "device_display_name": "Alice's mobile phone" + } + }); + + let device_keys: DeviceKeys = serde_json::from_value(device_keys).unwrap(); + + Device::from(&device_keys) + } + + #[test] + fn create_a_device() { + let user_id = UserId::try_from("@alice:example.org").unwrap(); + let device_id = "DEVICEID"; + + let device = get_device(); + + assert_eq!(&user_id, device.user_id()); + assert_eq!(device_id, device.device_id()); + assert_eq!(device.algorithms.len(), 2); + } +} diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index 2d9f5e11..b0e39bb6 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -114,36 +114,45 @@ impl GroupSessionStore { } } +/// In-memory store holding the devices of users. #[derive(Clone, Debug)] pub struct DeviceStore { entries: Arc>>, } +/// A read only view over all devices belonging to a user. pub struct UserDevices { entries: ReadOnlyView, } impl UserDevices { + /// Get the specific device with the given device id. pub fn get(&self, device_id: &str) -> Option { self.entries.get(device_id).cloned() } + /// Iterator over all the device ids of the user devices. pub fn keys(&self) -> impl Iterator { self.entries.keys() } + /// Iterator over all the devices of the user devices. pub fn devices(&self) -> impl Iterator { self.entries.values() } } impl DeviceStore { + /// Create a new empty device store. pub fn new() -> Self { DeviceStore { entries: Arc::new(DashMap::new()), } } + /// Add a device to the store. + /// + /// Returns true if the device was already in the store, false otherwise. pub fn add(&self, device: Device) -> bool { let user_id = device.user_id(); @@ -157,12 +166,14 @@ impl DeviceStore { .is_some() } + /// Get the device with the given device_id and belonging to the given user. pub fn get(&self, user_id: &UserId, device_id: &str) -> Option { self.entries .get(user_id) .and_then(|m| m.get(device_id).map(|d| d.value().clone())) } + /// Get a read-only view over all devices of the given user. pub fn user_devices(&self, user_id: &UserId) -> UserDevices { if !self.entries.contains_key(user_id) { self.entries.insert(user_id.clone(), DashMap::new()); @@ -179,6 +190,7 @@ mod test { use std::convert::TryFrom; use crate::api::r0::keys::SignedKey; + use crate::crypto::device::test::get_device; use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore}; use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession, Session}; use crate::identifiers::RoomId; @@ -226,6 +238,21 @@ mod test { assert_eq!(&session, loaded_session); } + #[tokio::test] + async fn test_session_store_bulk_storing() { + let (account, session) = get_account_and_session().await; + + let mut store = SessionStore::new(); + store.set_for_sender(&session.sender_key, vec![session.clone()]); + + let sessions = store.get(&session.sender_key).unwrap(); + let sessions = sessions.lock().await; + + let loaded_session = &sessions[0]; + + assert_eq!(&session, loaded_session); + } + #[tokio::test] async fn test_group_session_store() { let alice = Account::new(); @@ -254,4 +281,26 @@ mod test { .unwrap(); assert_eq!(inbound, loaded_session); } + + #[tokio::test] + async fn test_device_store() { + let device = get_device(); + let store = DeviceStore::new(); + + assert!(!store.add(device.clone())); + assert!(store.add(device.clone())); + + let loaded_device = store.get(device.user_id(), device.device_id()).unwrap(); + + assert_eq!(device, loaded_device); + + let user_devices = store.user_devices(device.user_id()); + + assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id()); + assert_eq!(user_devices.devices().nth(0).unwrap(), &device); + + let loaded_device = user_devices.get(device.device_id()).unwrap(); + + assert_eq!(device, loaded_device); + } }