diff --git a/Cargo.toml b/Cargo.toml index 6a3b516c..46ea2007 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,8 @@ zeroize = { version = "1.1.0", optional = true } thiserror = "1.0.13" async-trait = { version = "0.1.26", optional = true } tracing = "0.1.13" +atomic = "0.4.5" +dashmap = "3.9.1" [dependencies.tracing-futures] version = "0.2.3" diff --git a/src/crypto/device.rs b/src/crypto/device.rs index fe8ff498..9b923b3d 100644 --- a/src/crypto/device.rs +++ b/src/crypto/device.rs @@ -13,22 +13,26 @@ // limitations under the License. use std::collections::HashMap; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use atomic::Atomic; use ruma_client_api::r0::keys::{DeviceKeys, KeyAlgorithm}; use ruma_events::Algorithm; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Device { - user_id: String, - device_id: String, - algorithms: Vec, - keys: HashMap, - display_name: Option, - deleted: bool, - trust_state: TrustState, + user_id: Arc, + device_id: Arc, + algorithms: Arc>, + keys: Arc>, + display_name: Arc>, + deleted: Arc, + trust_state: Arc>, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum TrustState { Verified, BlackListed, @@ -37,7 +41,7 @@ pub enum TrustState { } impl Device { - pub fn id(&self) -> &str { + pub fn device_id(&self) -> &str { &self.device_id } @@ -56,16 +60,18 @@ impl From<&DeviceKeys> for Device { } Device { - user_id: device_keys.user_id.to_string(), - device_id: device_keys.device_id.clone(), - algorithms: device_keys.algorithms.clone(), - keys, - display_name: device_keys - .unsigned - .as_ref() - .map(|d| d.device_display_name.clone()), - deleted: false, - trust_state: TrustState::Unset, + user_id: Arc::new(device_keys.user_id.to_string()), + device_id: Arc::new(device_keys.device_id.clone()), + algorithms: Arc::new(device_keys.algorithms.clone()), + keys: Arc::new(keys), + display_name: Arc::new( + device_keys + .unsigned + .as_ref() + .map(|d| d.device_display_name.clone()), + ), + deleted: Arc::new(AtomicBool::new(false)), + trust_state: Arc::new(Atomic::new(TrustState::Unset)), } } } diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index abfd8196..241378b0 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -240,17 +240,27 @@ impl OlmMachine { let device = self .store - .get_user_device(&user_id_string, device_id) + .get_device(&user_id_string, device_id) .await .expect("Can't load device"); if let Some(d) = device { - todo!() + // TODO check what and if anything changed for the device. } else { let device = Device::from(device_keys); info!("Found new device {:?}", device); } } + + let current_devices: HashSet<&String> = device_map.keys().collect(); + let stored_devices = self.store.get_user_devices(&user_id_string).await.unwrap(); + let stored_devices_set: HashSet<&String> = stored_devices.keys().collect(); + + let deleted_devices = stored_devices_set.difference(¤t_devices); + + for device_id in deleted_devices { + // TODO delete devices here. + } } Ok(()) } diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index e86961a1..cfdb8799 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -15,8 +15,10 @@ use std::collections::HashMap; use std::sync::Arc; +use dashmap::{DashMap, ReadOnlyView}; use tokio::sync::Mutex; +use super::device::Device; use super::olm::{InboundGroupSession, Session}; #[derive(Debug)] @@ -96,3 +98,57 @@ impl GroupSessionStore { .and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id).cloned())) } } + +#[derive(Debug)] +pub struct DeviceStore { + entries: DashMap>, +} + +pub struct UserDevices { + entries: ReadOnlyView, +} + +impl UserDevices { + pub fn get(&self, device_id: &str) -> Option { + self.entries.get(device_id).cloned() + } + + pub fn keys(&self) -> impl Iterator { + self.entries.keys() + } +} + +impl DeviceStore { + pub fn new() -> Self { + DeviceStore { + entries: DashMap::new(), + } + } + + pub fn add(&self, device: Device) -> bool { + if !self.entries.contains_key(device.user_id()) { + self.entries + .insert(device.user_id().to_owned(), DashMap::new()); + } + let mut device_map = self.entries.get_mut(device.user_id()).unwrap(); + + device_map + .insert(device.device_id().to_owned(), device) + .is_some() + } + + pub fn get(&self, user_id: &str, device_id: &str) -> Option { + self.entries + .get(user_id) + .and_then(|m| m.get(device_id).map(|d| d.value().clone())) + } + + pub fn user_devices(&self, user_id: &str) -> UserDevices { + if !self.entries.contains_key(user_id) { + self.entries.insert(user_id.to_owned(), DashMap::new()); + } + UserDevices { + entries: self.entries.get(user_id).unwrap().clone().into_read_only(), + } + } +} diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index 2c522cf7..7084119c 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -20,13 +20,14 @@ use tokio::sync::Mutex; use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; use crate::crypto::device::Device; -use crate::crypto::memory_stores::{GroupSessionStore, SessionStore}; +use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; #[derive(Debug)] pub struct MemoryStore { sessions: SessionStore, inbound_group_sessions: GroupSessionStore, tracked_users: HashSet, + devices: DeviceStore, } impl MemoryStore { @@ -35,6 +36,7 @@ impl MemoryStore { sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), tracked_users: HashSet::new(), + devices: DeviceStore::new(), } } } @@ -88,7 +90,11 @@ impl CryptoStore for MemoryStore { Ok(self.tracked_users.insert(user.to_string())) } - async fn get_user_device(&self, user_id: &str, device_id: &str) -> Result> { - Ok(None) + async fn get_device(&self, user_id: &str, device_id: &str) -> Result> { + Ok(self.devices.get(user_id, device_id)) + } + + async fn get_user_devices(&self, user_id: &str) -> Result { + Ok(self.devices.user_devices(user_id)) } } diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index 718116b9..c1f6a76c 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -25,6 +25,7 @@ use thiserror::Error; use tokio::sync::Mutex; use super::device::Device; +use super::memory_stores::UserDevices; use super::olm::{Account, InboundGroupSession, Session}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::PicklingMode; @@ -82,5 +83,6 @@ pub trait CryptoStore: Debug + Send + Sync { ) -> Result>>>; fn tracked_users(&self) -> &HashSet; async fn add_user_for_tracking(&mut self, user: &str) -> Result; - async fn get_user_device(&self, user_id: &str, device_id: &str) -> Result>; + async fn get_device(&self, user_id: &str, device_id: &str) -> Result>; + async fn get_user_devices(&self, user_id: &str) -> Result; } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index bf2117de..5adce8f8 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -28,7 +28,7 @@ use zeroize::Zeroizing; use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; use crate::crypto::device::Device; -use crate::crypto::memory_stores::{GroupSessionStore, SessionStore}; +use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices}; pub struct SqliteStore { user_id: Arc, @@ -410,7 +410,11 @@ impl CryptoStore for SqliteStore { Ok(self.tracked_users.insert(user.to_string())) } - async fn get_user_device(&self, user_id: &str, device_id: &str) -> Result> { + async fn get_device(&self, user_id: &str, device_id: &str) -> Result> { + todo!() + } + + async fn get_user_devices(&self, user_id: &str) -> Result { todo!() } }