diff --git a/src/async_client.rs b/src/async_client.rs index d32b71f5..c7a01e69 100644 --- a/src/async_client.rs +++ b/src/async_client.rs @@ -623,7 +623,7 @@ impl AsyncClient { /// * `data` - The content of the message. pub async fn room_send( &mut self, - room_id: &str, + room_id: &RoomId, data: MessageEventContent, ) -> Result { #[cfg(feature = "encryption")] @@ -658,7 +658,7 @@ impl AsyncClient { } let request = create_message_event::Request { - room_id: RoomId::try_from(room_id).unwrap(), + room_id: room_id.clone(), event_type: EventType::RoomMessage, txn_id: self.transaction_id().to_string(), data, @@ -771,7 +771,7 @@ impl AsyncClient { let mut device_keys: HashMap> = HashMap::new(); for user in users_for_query.drain() { - device_keys.insert(UserId::try_from(user.as_ref()).unwrap(), Vec::new()); + device_keys.insert(user, Vec::new()); } let request = get_keys::Request { diff --git a/src/base_client.rs b/src/base_client.rs index d8930ac6..bc18efca 100644 --- a/src/base_client.rs +++ b/src/base_client.rs @@ -400,8 +400,8 @@ impl Client { #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] pub async fn get_missing_sessions( &self, - users: impl Iterator, - ) -> HashMap> { + users: impl Iterator, + ) -> HashMap> { let mut olm = self.olm.lock().await; match &mut *olm { @@ -431,7 +431,7 @@ impl Client { /// Returns an empty error if no keys need to be queried. #[cfg(feature = "encryption")] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] - pub async fn users_for_key_query(&self) -> StdResult, ()> { + pub async fn users_for_key_query(&self) -> StdResult, ()> { let olm = self.olm.lock().await; match &*olm { diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 48d678ab..7c673a1a 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -71,7 +71,7 @@ pub struct OlmMachine { store: Box, /// Set of users that we need to query keys for. This is a subset of /// the tracked users in the CryptoStore. - users_for_key_query: HashSet, + users_for_key_query: HashSet, } impl OlmMachine { @@ -101,8 +101,7 @@ impl OlmMachine { passphrase: String, ) -> Result { let mut store = - SqliteStore::open_with_passphrase(&user_id.to_string(), device_id, path, passphrase) - .await?; + SqliteStore::open_with_passphrase(&user_id, device_id, path, passphrase).await?; let account = match store.load_account().await? { Some(a) => { @@ -183,12 +182,12 @@ impl OlmMachine { pub async fn get_missing_sessions( &mut self, - users: impl Iterator, + users: impl Iterator, ) -> HashMap> { let mut missing = HashMap::new(); for user_id in users { - let user_devices = self.store.get_user_devices(&user_id).await.unwrap(); + let user_devices = self.store.get_user_devices(user_id).await.unwrap(); for device in user_devices.devices() { let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) { @@ -206,12 +205,11 @@ impl OlmMachine { }; if is_missing { - let user_id = UserId::try_from(user_id.as_ref()).unwrap(); - if !missing.contains_key(&user_id) { - missing.insert(user_id.to_owned(), HashMap::new()); + if !missing.contains_key(user_id) { + missing.insert(user_id.clone(), HashMap::new()); } - let user_map = missing.get_mut(&user_id).unwrap(); + let user_map = missing.get_mut(user_id).unwrap(); user_map.insert( device.device_id().to_owned(), KeyAlgorithm::SignedCurve25519, @@ -233,7 +231,7 @@ impl OlmMachine { for (device_id, key_map) in user_devices { let device = if let Some(d) = self .store - .get_device(&user_id.to_string(), device_id) + .get_device(&user_id, device_id) .await .expect("Can't get devices") { @@ -346,8 +344,7 @@ impl OlmMachine { let mut changed_devices = Vec::new(); for (user_id, device_map) in &response.device_keys { - let user_id_string = user_id.to_string(); - self.users_for_key_query.remove(&user_id_string); + self.users_for_key_query.remove(&user_id); for (device_id, device_keys) in device_map.iter() { // We don't need our own device in the device store. @@ -393,7 +390,7 @@ impl OlmMachine { let device = self .store - .get_device(&user_id_string, device_id) + .get_device(&user_id, device_id) .await .expect("Can't load device"); @@ -407,7 +404,7 @@ impl OlmMachine { } let current_devices: HashSet<&String> = device_map.keys().collect(); - let stored_devices = self.store.get_user_devices(&user_id_string).await.unwrap(); + let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); let stored_devices_set: HashSet<&String> = stored_devices.keys().collect(); let deleted_devices = stored_devices_set.difference(¤t_devices); @@ -767,7 +764,7 @@ impl OlmMachine { let session = InboundGroupSession::new( sender_key, signing_key, - &event.content.room_id.to_string(), + &event.content.room_id, &event.content.session_key, )?; self.store.save_inbound_group_session(session).await?; @@ -893,11 +890,7 @@ impl OlmMachine { let session = self .store - .get_inbound_group_session( - &room_id.to_string(), - &content.sender_key, - &content.session_id, - ) + .get_inbound_group_session(&room_id, &content.sender_key, &content.session_id) .await?; // TODO check if the olm session is wedged and re-request the key. let session = session.ok_or(OlmError::MissingSession)?; @@ -936,7 +929,7 @@ impl OlmMachine { /// Use the `mark_user_as_changed()` if the user really needs a key query. pub async fn update_tracked_users<'a, I>(&mut self, users: I) where - I: IntoIterator, + I: IntoIterator, { for user in users { let ret = self.store.add_user_for_tracking(user).await; @@ -944,12 +937,12 @@ impl OlmMachine { match ret { Ok(newly_added) => { if newly_added { - self.users_for_key_query.insert(user.to_string()); + self.users_for_key_query.insert(user.clone()); } } Err(e) => { warn!("Error storing users for tracking {}", e); - self.users_for_key_query.insert(user.to_string()); + self.users_for_key_query.insert(user.clone()); } } } @@ -961,7 +954,7 @@ impl OlmMachine { } /// Get the set of users that we need to query keys for. - pub fn users_for_key_query(&self) -> HashSet { + pub fn users_for_key_query(&self) -> HashSet { self.users_for_key_query.clone() } } diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index cffd5187..dee2378f 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashMap; +use std::convert::TryFrom; use std::sync::Arc; use dashmap::{DashMap, ReadOnlyView}; @@ -20,6 +21,7 @@ use tokio::sync::Mutex; use super::device::Device; use super::olm::{InboundGroupSession, Session}; +use crate::identifiers::{RoomId, UserId}; #[derive(Debug)] pub struct SessionStore { @@ -59,7 +61,7 @@ impl SessionStore { #[derive(Debug)] pub struct GroupSessionStore { - entries: HashMap>>>>, + entries: HashMap>>>>, } impl GroupSessionStore { @@ -89,7 +91,7 @@ impl GroupSessionStore { pub fn get( &self, - room_id: &str, + room_id: &RoomId, sender_key: &str, session_id: &str, ) -> Option>> { @@ -101,7 +103,7 @@ impl GroupSessionStore { #[derive(Clone, Debug)] pub struct DeviceStore { - entries: Arc>>, + entries: Arc>>, } pub struct UserDevices { @@ -130,26 +132,27 @@ impl DeviceStore { } pub fn add(&self, device: Device) -> bool { - if !self.entries.contains_key(device.user_id()) { - self.entries - .insert(device.user_id().to_owned(), DashMap::new()); + let user_id = UserId::try_from(device.user_id()).unwrap(); + + if !self.entries.contains_key(&user_id) { + self.entries.insert(user_id.clone(), DashMap::new()); } - let device_map = self.entries.get_mut(device.user_id()).unwrap(); + let device_map = self.entries.get_mut(&user_id).unwrap(); device_map .insert(device.device_id().to_owned(), device) .is_some() } - pub fn get(&self, user_id: &str, device_id: &str) -> Option { + pub fn get(&self, user_id: &UserId, device_id: &str) -> Option { self.entries .get(user_id) .and_then(|m| m.get(device_id).map(|d| d.value().clone())) } - pub fn user_devices(&self, user_id: &str) -> UserDevices { + pub fn user_devices(&self, user_id: &UserId) -> UserDevices { if !self.entries.contains_key(user_id) { - self.entries.insert(user_id.to_owned(), DashMap::new()); + self.entries.insert(user_id.clone(), DashMap::new()); } UserDevices { entries: self.entries.get(user_id).unwrap().clone().into_read_only(), diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 9ea5b3fe..659617cd 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -23,6 +23,8 @@ use olm_rs::PicklingMode; use ruma_client_api::r0::keys::SignedKey; +use crate::identifiers::{RoomId, UserId}; + pub struct Account { inner: OlmAccount, pub(crate) shared: bool, @@ -210,7 +212,7 @@ pub struct InboundGroupSession { inner: OlmInboundGroupSession, pub(crate) sender_key: String, pub(crate) signing_key: String, - pub(crate) room_id: String, + pub(crate) room_id: RoomId, forwarding_chains: Option>, } @@ -218,14 +220,14 @@ impl InboundGroupSession { pub fn new( sender_key: &str, signing_key: &str, - room_id: &str, + room_id: &RoomId, session_key: &str, ) -> Result { Ok(InboundGroupSession { inner: OlmInboundGroupSession::new(session_key)?, sender_key: sender_key.to_owned(), signing_key: signing_key.to_owned(), - room_id: room_id.to_owned(), + room_id: room_id.clone(), forwarding_chains: None, }) } @@ -235,7 +237,7 @@ impl InboundGroupSession { pickle_mode: PicklingMode, sender_key: String, signing_key: String, - room_id: String, + room_id: RoomId, ) -> Result { let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?; Ok(InboundGroupSession { diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index 9932e060..2ab3bc97 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -21,12 +21,13 @@ use tokio::sync::Mutex; use super::{Account, CryptoStore, InboundGroupSession, Result, Session}; use crate::crypto::device::Device; use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; +use crate::identifiers::{RoomId, UserId}; #[derive(Debug)] pub struct MemoryStore { sessions: SessionStore, inbound_group_sessions: GroupSessionStore, - tracked_users: HashSet, + tracked_users: HashSet, devices: DeviceStore, } @@ -73,7 +74,7 @@ impl CryptoStore for MemoryStore { async fn get_inbound_group_session( &mut self, - room_id: &str, + room_id: &RoomId, sender_key: &str, session_id: &str, ) -> Result>>> { @@ -82,19 +83,19 @@ impl CryptoStore for MemoryStore { .get(room_id, sender_key, session_id)) } - fn tracked_users(&self) -> &HashSet { + fn tracked_users(&self) -> &HashSet { &self.tracked_users } - async fn add_user_for_tracking(&mut self, user: &str) -> Result { - Ok(self.tracked_users.insert(user.to_string())) + async fn add_user_for_tracking(&mut self, user: &UserId) -> Result { + Ok(self.tracked_users.insert(user.clone())) } - async fn get_device(&self, user_id: &str, device_id: &str) -> Result> { + async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result> { Ok(self.devices.get(user_id, device_id)) } - async fn get_user_devices(&self, user_id: &str) -> Result { + async fn get_user_devices(&self, user_id: &UserId) -> Result { Ok(self.devices.user_devices(user_id)) } diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index e4367c80..aa485f82 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -26,6 +26,7 @@ use tokio::sync::Mutex; use super::device::Device; use super::memory_stores::UserDevices; use super::olm::{Account, InboundGroupSession, Session}; +use crate::identifiers::{RoomId, UserId}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; pub mod memorystore; @@ -75,13 +76,13 @@ pub trait CryptoStore: Debug + Send + Sync { async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result; async fn get_inbound_group_session( &mut self, - room_id: &str, + room_id: &RoomId, sender_key: &str, session_id: &str, ) -> Result>>>; - fn tracked_users(&self) -> &HashSet; - async fn add_user_for_tracking(&mut self, user: &str) -> Result; + fn tracked_users(&self) -> &HashSet; + async fn add_user_for_tracking(&mut self, user: &UserId) -> Result; async fn save_device(&self, device: Device) -> Result<()>; - async fn get_device(&self, user_id: &str, device_id: &str) -> Result>; - async fn get_user_devices(&self, user_id: &str) -> Result; + async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result>; + async fn get_user_devices(&self, user_id: &UserId) -> Result; } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 3bc3e7cb..8cfcb3a9 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashSet; +use std::convert::TryFrom; use std::path::{Path, PathBuf}; use std::result::Result as StdResult; use std::sync::Arc; @@ -29,6 +30,7 @@ use zeroize::Zeroizing; use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; use crate::crypto::device::Device; use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices}; +use crate::identifiers::{RoomId, UserId}; pub struct SqliteStore { user_id: Arc, @@ -39,14 +41,14 @@ pub struct SqliteStore { inbound_group_sessions: GroupSessionStore, connection: Arc>, pickle_passphrase: Option>, - tracked_users: HashSet, + tracked_users: HashSet, } static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; impl SqliteStore { pub async fn open>( - user_id: &str, + user_id: &UserId, device_id: &str, path: P, ) -> Result { @@ -54,7 +56,7 @@ impl SqliteStore { } pub async fn open_with_passphrase>( - user_id: &str, + user_id: &UserId, device_id: &str, path: P, passphrase: String, @@ -69,7 +71,7 @@ impl SqliteStore { } async fn open_helper>( - user_id: &str, + user_id: &UserId, device_id: &str, path: P, passphrase: Option>, @@ -78,7 +80,7 @@ impl SqliteStore { let connection = SqliteConnection::connect(url.as_ref()).await?; let store = SqliteStore { - user_id: Arc::new(user_id.to_owned()), + user_id: Arc::new(user_id.to_string()), device_id: Arc::new(device_id.to_owned()), account_id: None, sessions: SessionStore::new(), @@ -230,7 +232,7 @@ impl SqliteStore { self.get_pickle_mode(), sender_key.to_string(), signing_key.to_owned(), - room_id.to_owned(), + RoomId::try_from(room_id.as_str()).unwrap(), )?) }) .collect::>>()?) @@ -302,8 +304,8 @@ impl CryptoStore for SqliteStore { device_id = ?2 ", ) - .bind(&*self.user_id) - .bind(&*self.device_id) + .bind(&*self.user_id.to_string()) + .bind(&*self.device_id.to_string()) .bind(&pickle) .bind(acc.shared) .execute(&mut *connection) @@ -311,8 +313,8 @@ impl CryptoStore for SqliteStore { let account_id: (i64,) = query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?") - .bind(&*self.user_id) - .bind(&*self.device_id) + .bind(&*self.user_id.to_string()) + .bind(&*self.device_id.to_string()) .fetch_one(&mut *connection) .await?; @@ -383,7 +385,7 @@ impl CryptoStore for SqliteStore { .bind(account_id) .bind(&session.sender_key) .bind(&session.signing_key) - .bind(&session.room_id) + .bind(&session.room_id.to_string()) .bind(&pickle) .execute(&mut *connection) .await?; @@ -393,7 +395,7 @@ impl CryptoStore for SqliteStore { async fn get_inbound_group_session( &mut self, - room_id: &str, + room_id: &RoomId, sender_key: &str, session_id: &str, ) -> Result>>> { @@ -402,19 +404,19 @@ impl CryptoStore for SqliteStore { .get(room_id, sender_key, session_id)) } - fn tracked_users(&self) -> &HashSet { + fn tracked_users(&self) -> &HashSet { &self.tracked_users } - async fn add_user_for_tracking(&mut self, user: &str) -> Result { - Ok(self.tracked_users.insert(user.to_string())) + async fn add_user_for_tracking(&mut self, user: &UserId) -> Result { + Ok(self.tracked_users.insert(user.clone())) } - async fn get_device(&self, _user_id: &str, _device_id: &str) -> Result> { + async fn get_device(&self, _user_id: &UserId, _device_id: &str) -> Result> { todo!() } - async fn get_user_devices(&self, _user_id: &str) -> Result { + async fn get_user_devices(&self, _user_id: &UserId) -> Result { todo!() } @@ -442,7 +444,9 @@ mod test { use tempfile::tempdir; use tokio::sync::Mutex; - use super::{Account, CryptoStore, InboundGroupSession, Session, SqliteStore}; + use super::{ + Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom, UserId, + }; static USER_ID: &str = "@example:localhost"; static DEVICE_ID: &str = "DEVICEID"; @@ -450,7 +454,7 @@ mod test { async fn get_store() -> SqliteStore { let tmpdir = tempdir().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap(); - SqliteStore::open(USER_ID, DEVICE_ID, tmpdir_path) + SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, tmpdir_path) .await .expect("Can't create store") } @@ -501,7 +505,7 @@ mod test { async fn create_store() { let tmpdir = tempdir().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap(); - let _ = SqliteStore::open("@example:localhost", "DEVICEID", tmpdir_path) + let _ = SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), "DEVICEID", tmpdir_path) .await .expect("Can't create store"); } @@ -626,7 +630,7 @@ mod test { let session = InboundGroupSession::new( identity_keys.curve25519(), identity_keys.ed25519(), - "!test:localhost", + &RoomId::try_from("!test:localhost").unwrap(), &outbound_session.session_key(), ) .expect("Can't create session"); @@ -647,7 +651,7 @@ mod test { let session = InboundGroupSession::new( identity_keys.curve25519(), identity_keys.ed25519(), - "!test:localhost", + &RoomId::try_from("!test:localhost").unwrap(), &outbound_session.session_key(), ) .expect("Can't create session"); diff --git a/src/event_emitter/mod.rs b/src/event_emitter/mod.rs index d7a27278..be837c32 100644 --- a/src/event_emitter/mod.rs +++ b/src/event_emitter/mod.rs @@ -65,7 +65,7 @@ use tokio::sync::Mutex; /// } = event.lock().await.deref() /// { /// let rooms = room.lock().await; -/// let member = rooms.members.get(&sender.to_string()).unwrap(); +/// let member = rooms.members.get(&sender).unwrap(); /// println!( /// "{}: {}", /// member