diff --git a/src/crypto/device.rs b/src/crypto/device.rs index cf1ddf49..c5d50a32 100644 --- a/src/crypto/device.rs +++ b/src/crypto/device.rs @@ -20,11 +20,12 @@ use atomic::Atomic; use ruma_client_api::r0::keys::{DeviceKeys, KeyAlgorithm}; use ruma_events::Algorithm; +use ruma_identifiers::{DeviceId, UserId}; #[derive(Debug, Clone)] pub struct Device { - user_id: Arc, - device_id: Arc, + user_id: Arc, + device_id: Arc, algorithms: Arc>, keys: Arc>, display_name: Arc>, @@ -47,11 +48,11 @@ pub enum TrustState { } impl Device { - pub fn device_id(&self) -> &str { + pub fn device_id(&self) -> &DeviceId { &self.device_id } - pub fn user_id(&self) -> &str { + pub fn user_id(&self) -> &UserId { &self.user_id } } @@ -66,7 +67,7 @@ impl From<&DeviceKeys> for Device { } Device { - user_id: Arc::new(device_keys.user_id.to_string()), + user_id: Arc::new(device_keys.user_id.clone()), device_id: Arc::new(device_keys.device_id.clone()), algorithms: Arc::new(device_keys.algorithms.clone()), keys: Arc::new(keys), diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 98554cbf..dd8d49dc 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -20,7 +20,7 @@ use std::result::Result as StdResult; use std::sync::Arc; use super::error::{OlmError, Result, SignatureError, VerificationResult}; -use super::olm::{Account, InboundGroupSession, OutboundGroupSession}; +use super::olm::{Account, InboundGroupSession, OutboundGroupSession, Session}; use super::store::memorystore::MemoryStore; #[cfg(feature = "sqlite-cryptostore")] use super::store::sqlite::SqliteStore; @@ -35,19 +35,25 @@ use serde_json::{json, Value}; use tokio::sync::Mutex; use tracing::{debug, error, info, instrument, trace, warn}; -use ruma_client_api::r0::client_exchange::DeviceIdOrAllDevices; +use ruma_client_api::r0::client_exchange::{ + send_event_to_device::Request as ToDeviceRequest, DeviceIdOrAllDevices, +}; use ruma_client_api::r0::keys::{ AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey, }; use ruma_client_api::r0::sync::sync_events::IncomingResponse as SyncResponse; use ruma_events::{ - collections::all::{Event, RoomEvent}, - room::encrypted::{EncryptedEvent, EncryptedEventContent}, + collections::all::RoomEvent, + room::encrypted::{ + CiphertextInfo, EncryptedEvent, EncryptedEventContent, MegolmV1AesSha2Content, + OlmV1Curve25519AesSha2Content, + }, + room::message::MessageEventContent, to_device::{ AnyToDeviceEvent as ToDeviceEvent, ToDeviceEncrypted, ToDeviceForwardedRoomKey, ToDeviceRoomKey, ToDeviceRoomKeyRequest, }, - Algorithm, EventResult, + Algorithm, EventResult, EventType, }; use ruma_identifiers::RoomId; use ruma_identifiers::{DeviceId, UserId}; @@ -409,9 +415,9 @@ impl OlmMachine { } } - let current_devices: HashSet<&String> = device_map.keys().collect(); + let current_devices: HashSet<&DeviceId> = device_map.keys().collect(); let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); - let stored_devices_set: HashSet<&String> = stored_devices.keys().collect(); + let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); let deleted_devices = stored_devices_set.difference(¤t_devices); diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index dee2378f..de2e67e3 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::collections::HashMap; -use std::convert::TryFrom; use std::sync::Arc; use dashmap::{DashMap, ReadOnlyView}; @@ -21,7 +20,7 @@ use tokio::sync::Mutex; use super::device::Device; use super::olm::{InboundGroupSession, Session}; -use crate::identifiers::{RoomId, UserId}; +use crate::identifiers::{DeviceId, RoomId, UserId}; #[derive(Debug)] pub struct SessionStore { @@ -107,7 +106,7 @@ pub struct DeviceStore { } pub struct UserDevices { - entries: ReadOnlyView, + entries: ReadOnlyView, } impl UserDevices { @@ -115,7 +114,7 @@ impl UserDevices { self.entries.get(device_id).cloned() } - pub fn keys(&self) -> impl Iterator { + pub fn keys(&self) -> impl Iterator { self.entries.keys() } @@ -132,7 +131,7 @@ impl DeviceStore { } pub fn add(&self, device: Device) -> bool { - let user_id = UserId::try_from(device.user_id()).unwrap(); + let user_id = device.user_id(); if !self.entries.contains_key(&user_id) { self.entries.insert(user_id.clone(), DashMap::new()); diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index f8259873..bafd2d56 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -292,6 +292,7 @@ impl OutboundGroupSession { pub fn new(room_id: &RoomId) -> Self { let session = OlmOutboundGroupSession::new(); let session_id = session.session_id(); + OutboundGroupSession { inner: Arc::new(Mutex::new(session)), room_id: Arc::new(room_id.to_owned()),