From 807435c043be6e41f56ca99cf6d76463bd7cb11e Mon Sep 17 00:00:00 2001 From: Devin R Date: Sat, 18 Jul 2020 08:51:19 -0400 Subject: [PATCH] Updates DeviceId to be Box --- matrix_sdk/src/client.rs | 6 +-- matrix_sdk/src/request_builder.rs | 4 +- matrix_sdk_common/Cargo.toml | 2 +- matrix_sdk_crypto/src/device.rs | 14 ++++--- matrix_sdk_crypto/src/error.rs | 10 ++--- matrix_sdk_crypto/src/lib.rs | 2 +- matrix_sdk_crypto/src/machine.rs | 37 +++++++--------- matrix_sdk_crypto/src/memory_stores.rs | 17 +++++--- matrix_sdk_crypto/src/olm.rs | 58 ++++++++++++++------------ matrix_sdk_crypto/src/store/sqlite.rs | 12 +++--- 10 files changed, 85 insertions(+), 77 deletions(-) diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 9b2f213e..8ee4c134 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -472,7 +472,7 @@ impl Client { login_info: login::LoginInfo::Password { password: password.into(), }, - device_id: device_id.map(|d| d.into()), + device_id: device_id.map(|d| d.into().into_boxed_str()), initial_device_display_name: initial_device_display_name.map(|d| d.into()), }; @@ -1407,7 +1407,7 @@ impl Client { #[instrument] async fn claim_one_time_keys( &self, - one_time_keys: BTreeMap>, + one_time_keys: BTreeMap, KeyAlgorithm>>, ) -> Result { let request = claim_keys::Request { timeout: None, @@ -1511,7 +1511,7 @@ impl Client { users_for_query ); - let mut device_keys: BTreeMap> = BTreeMap::new(); + let mut device_keys: BTreeMap>> = BTreeMap::new(); for user in users_for_query.drain() { device_keys.insert(user, Vec::new()); diff --git a/matrix_sdk/src/request_builder.rs b/matrix_sdk/src/request_builder.rs index d397d78f..9d0cf748 100644 --- a/matrix_sdk/src/request_builder.rs +++ b/matrix_sdk/src/request_builder.rs @@ -275,7 +275,7 @@ impl Into for MessagesRequestBuilder { pub struct RegistrationBuilder { password: Option, username: Option, - device_id: Option, + device_id: Option>, initial_device_display_name: Option, auth: Option, kind: Option, @@ -309,7 +309,7 @@ impl RegistrationBuilder { /// /// If this does not correspond to a known client device, a new device will be created. /// The server will auto-generate a device_id if this is not specified. - pub fn device_id>(&mut self, device_id: S) -> &mut Self { + pub fn device_id>>(&mut self, device_id: S) -> &mut Self { self.device_id = Some(device_id.into()); self } diff --git a/matrix_sdk_common/Cargo.toml b/matrix_sdk_common/Cargo.toml index 05b0a9d1..2dd86448 100644 --- a/matrix_sdk_common/Cargo.toml +++ b/matrix_sdk_common/Cargo.toml @@ -17,7 +17,7 @@ js_int = "0.1.8" [dependencies.ruma] git = "https://github.com/ruma/ruma" features = ["client-api"] -rev = "848b225" +rev = "848b22568106d05c5444f3fe46070d5aa16e422b" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] uuid = { version = "0.8.1", features = ["v4"] } diff --git a/matrix_sdk_crypto/src/device.rs b/matrix_sdk_crypto/src/device.rs index e7a5ec14..17804bca 100644 --- a/matrix_sdk_crypto/src/device.rs +++ b/matrix_sdk_crypto/src/device.rs @@ -33,7 +33,7 @@ use crate::verify_json; #[derive(Debug, Clone)] pub struct Device { user_id: Arc, - device_id: Arc, + device_id: Arc>, algorithms: Arc>, keys: Arc>, display_name: Arc>, @@ -70,7 +70,7 @@ impl Device { /// Create a new Device. pub fn new( user_id: UserId, - device_id: DeviceId, + device_id: Box, display_name: Option, trust_state: TrustState, algorithms: Vec, @@ -104,8 +104,10 @@ impl Device { /// Get the key of the given key algorithm belonging to this device. pub fn get_key(&self, algorithm: KeyAlgorithm) -> Option<&String> { - self.keys - .get(&AlgorithmAndDeviceId(algorithm, self.device_id.to_string())) + self.keys.get(&AlgorithmAndDeviceId( + algorithm, + self.device_id.as_ref().clone(), + )) } /// Get a map containing all the device keys. @@ -180,7 +182,7 @@ impl From<&OlmMachine> for Device { fn from(machine: &OlmMachine) -> Self { Device { user_id: Arc::new(machine.user_id().clone()), - device_id: Arc::new(machine.device_id().clone()), + device_id: Arc::new(machine.device_id().into()), algorithms: Arc::new(vec![ Algorithm::MegolmV1AesSha2, Algorithm::OlmV1Curve25519AesSha2, @@ -193,7 +195,7 @@ impl From<&OlmMachine> for Device { ( AlgorithmAndDeviceId( KeyAlgorithm::try_from(key.as_ref()).unwrap(), - machine.device_id().clone(), + machine.device_id().into(), ), value.to_owned(), ) diff --git a/matrix_sdk_crypto/src/error.rs b/matrix_sdk_crypto/src/error.rs index 4dcc207b..5ec901c3 100644 --- a/matrix_sdk_crypto/src/error.rs +++ b/matrix_sdk_crypto/src/error.rs @@ -128,21 +128,21 @@ pub(crate) enum SessionCreationError { "Failed to create a new Olm session for {0} {1}, the requested \ one-time key isn't a signed curve key" )] - OneTimeKeyNotSigned(UserId, DeviceId), + OneTimeKeyNotSigned(UserId, Box), #[error( "Tried to create a new Olm session for {0} {1}, but the signed \ one-time key is missing" )] - OneTimeKeyMissing(UserId, DeviceId), + OneTimeKeyMissing(UserId, Box), #[error("Failed to verify the one-time key signatures for {0} {1}: {2:?}")] - InvalidSignature(UserId, DeviceId, SignatureError), + InvalidSignature(UserId, Box, SignatureError), #[error( "Tried to create an Olm session for {0} {1}, but the device is missing \ a curve25519 key" )] - DeviceMissingCurveKey(UserId, DeviceId), + DeviceMissingCurveKey(UserId, Box), #[error("Error creating new Olm session for {0} {1}: {2:?}")] - OlmError(UserId, DeviceId, OlmSessionError), + OlmError(UserId, Box, OlmSessionError), } impl From for SignatureError { diff --git a/matrix_sdk_crypto/src/lib.rs b/matrix_sdk_crypto/src/lib.rs index cd5556bc..148480ca 100644 --- a/matrix_sdk_crypto/src/lib.rs +++ b/matrix_sdk_crypto/src/lib.rs @@ -82,7 +82,7 @@ pub(crate) fn verify_json( json_object.insert("unsigned".to_string(), u); } - let key_id = AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, key_id.to_string()); + let key_id = AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, key_id.into()); let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?; let signature_object = signatures diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 902cdb39..6bbbce33 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -64,7 +64,7 @@ pub struct OlmMachine { /// The unique user id that owns this account. user_id: UserId, /// The unique device id of the device that holds this account. - device_id: DeviceId, + device_id: Box, /// Our underlying Olm Account holding our identity keys. account: Account, /// Store for the encryption keys. @@ -102,7 +102,7 @@ impl OlmMachine { pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self { OlmMachine { user_id: user_id.clone(), - device_id: device_id.to_owned(), + device_id: device_id.into(), account: Account::new(user_id, &device_id), store: Box::new(MemoryStore::new()), outbound_group_sessions: HashMap::new(), @@ -128,7 +128,7 @@ impl OlmMachine { /// the encryption keys. pub async fn new_with_store( user_id: UserId, - device_id: String, + device_id: Box, mut store: Box, ) -> StoreResult { let account = match store.load_account().await? { @@ -171,7 +171,7 @@ impl OlmMachine { let store = SqliteStore::open_with_passphrase(&user_id, device_id, path, passphrase).await?; - OlmMachine::new_with_store(user_id.to_owned(), device_id.to_owned(), Box::new(store)).await + OlmMachine::new_with_store(user_id.to_owned(), device_id.into(), Box::new(store)).await } /// The unique user id that owns this identity. @@ -255,7 +255,7 @@ impl OlmMachine { pub async fn get_missing_sessions( &mut self, users: impl Iterator, - ) -> OlmResult>> { + ) -> OlmResult, KeyAlgorithm>>> { let mut missing = BTreeMap::new(); for user_id in users { @@ -282,10 +282,8 @@ impl OlmMachine { } let user_map = missing.get_mut(user_id).unwrap(); - let _ = user_map.insert( - device.device_id().to_owned(), - KeyAlgorithm::SignedCurve25519, - ); + let _ = + user_map.insert(device.device_id().into(), KeyAlgorithm::SignedCurve25519); } } } @@ -356,7 +354,7 @@ impl OlmMachine { async fn handle_devices_from_key_query( &mut self, - device_keys_map: &BTreeMap>, + device_keys_map: &BTreeMap, DeviceKeys>>, ) -> StoreResult> { let mut changed_devices = Vec::new(); @@ -406,7 +404,8 @@ impl OlmMachine { changed_devices.push(device); } - let current_devices: HashSet<&DeviceId> = device_map.keys().collect(); + let current_devices: HashSet<&DeviceId> = + device_map.keys().map(|id| id.as_ref()).collect(); let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); @@ -843,10 +842,7 @@ impl OlmMachine { let message_type: usize = ciphertext.0.into(); - let ciphertext = CiphertextInfo { - body: ciphertext.1, - message_type: (message_type as u32).into(), - }; + let ciphertext = CiphertextInfo::new(ciphertext.1, (message_type as u32).into()); let mut content = BTreeMap::new(); @@ -855,10 +851,7 @@ impl OlmMachine { self.store.save_sessions(&[session]).await?; Ok(EncryptedEventContent::OlmV1Curve25519AesSha2( - OlmV1Curve25519AesSha2Content { - sender_key: identity_keys.curve25519().to_owned(), - ciphertext: content, - }, + OlmV1Curve25519AesSha2Content::new(content, identity_keys.curve25519().to_owned()), )) } @@ -989,7 +982,7 @@ impl OlmMachine { .await?; user_messages.insert( - DeviceIdOrAllDevices::DeviceId(device.device_id().clone()), + DeviceIdOrAllDevices::DeviceId(device.device_id().into()), serde_json::value::to_raw_value(&encrypted_content)?, ); } @@ -1254,8 +1247,8 @@ mod test { UserId::try_from("@alice:example.org").unwrap() } - fn alice_device_id() -> DeviceId { - "JLAFKJWSCS".to_string() + fn alice_device_id() -> Box { + "JLAFKJWSCS".into() } fn user_id() -> UserId { diff --git a/matrix_sdk_crypto/src/memory_stores.rs b/matrix_sdk_crypto/src/memory_stores.rs index f38f476e..51303864 100644 --- a/matrix_sdk_crypto/src/memory_stores.rs +++ b/matrix_sdk_crypto/src/memory_stores.rs @@ -129,13 +129,13 @@ impl GroupSessionStore { /// In-memory store holding the devices of users. #[derive(Clone, Debug, Default)] pub struct DeviceStore { - entries: Arc>>, + entries: Arc, Device>>>, } /// A read only view over all devices belonging to a user. #[derive(Debug)] pub struct UserDevices { - entries: ReadOnlyView, + entries: ReadOnlyView, Device>, } impl UserDevices { @@ -146,7 +146,7 @@ impl UserDevices { /// Iterator over all the device ids of the user devices. pub fn keys(&self) -> impl Iterator { - self.entries.keys() + self.entries.keys().map(|id| id.as_ref()) } /// Iterator over all the devices of the user devices. @@ -175,7 +175,9 @@ impl DeviceStore { let device_map = self.entries.get_mut(&user_id).unwrap(); device_map - .insert(device.device_id().to_owned(), device) + // TODO this is ok if this is for sure a valid device_id otherwise + // Box::::try_from(&str) is the validated version + .insert(device.device_id().into(), device) .is_none() } @@ -202,7 +204,12 @@ impl DeviceStore { self.entries.insert(user_id.clone(), DashMap::new()); } UserDevices { - entries: self.entries.get(user_id).unwrap().clone().into_read_only(), + entries: self + .entries + .get(user_id) + .map(|d| d.clone()) // TODO I'm sure this is not ok but I'm not sure what to do?? + .unwrap() + .into_read_only(), } } } diff --git a/matrix_sdk_crypto/src/olm.rs b/matrix_sdk_crypto/src/olm.rs index 5c4f6f0d..106445c6 100644 --- a/matrix_sdk_crypto/src/olm.rs +++ b/matrix_sdk_crypto/src/olm.rs @@ -40,7 +40,6 @@ pub use olm_rs::{ utility::OlmUtility, }; -use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId}; use matrix_sdk_common::{ api::r0::keys::{AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey}, events::{ @@ -48,8 +47,9 @@ use matrix_sdk_common::{ encrypted::{EncryptedEventContent, MegolmV1AesSha2Content}, message::MessageEventContent, }, - Algorithm, AnyRoomEventStub, EventJson, EventType, MessageEventStub, + Algorithm, AnySyncRoomEvent, EventJson, EventType, SyncMessageEvent, }, + identifiers::{DeviceId, RoomId, UserId}, }; /// Account holding identity keys for which sessions can be created. @@ -59,7 +59,7 @@ use matrix_sdk_common::{ #[derive(Clone)] pub struct Account { user_id: Arc, - device_id: Arc, + device_id: Arc>, inner: Arc>, identity_keys: Arc, shared: Arc, @@ -94,7 +94,7 @@ impl Account { Account { user_id: Arc::new(user_id.to_owned()), - device_id: Arc::new(device_id.to_owned()), + device_id: Arc::new(device_id.into()), inner: Arc::new(Mutex::new(account)), identity_keys: Arc::new(identity_keys), shared: Arc::new(AtomicBool::new(false)), @@ -267,7 +267,7 @@ impl Account { Ok(Account { user_id: Arc::new(user_id.to_owned()), - device_id: Arc::new(device_id.to_owned()), + device_id: Arc::new(device_id.into()), inner: Arc::new(Mutex::new(account)), identity_keys: Arc::new(identity_keys), shared: Arc::new(AtomicBool::from(shared)), @@ -371,7 +371,7 @@ impl Account { }; one_time_key_map.insert( - AlgorithmAndDeviceId(KeyAlgorithm::SignedCurve25519, key_id.to_owned()), + AlgorithmAndDeviceId(KeyAlgorithm::SignedCurve25519, key_id.as_str().into()), OneTimeKey::SignedKey(signed_key), ); } @@ -431,7 +431,7 @@ impl Account { let one_time_key = key_map.values().next().ok_or_else(|| { SessionCreationError::OneTimeKeyMissing( device.user_id().to_owned(), - device.device_id().to_owned(), + device.device_id().into(), ) })?; @@ -440,7 +440,7 @@ impl Account { OneTimeKey::Key(_) => { return Err(SessionCreationError::OneTimeKeyNotSigned( device.user_id().to_owned(), - device.device_id().to_owned(), + device.device_id().into(), )); } }; @@ -448,7 +448,7 @@ impl Account { device.verify_one_time_key(&one_time_key).map_err(|e| { SessionCreationError::InvalidSignature( device.user_id().to_owned(), - device.device_id().to_owned(), + device.device_id().into(), e, ) })?; @@ -456,7 +456,7 @@ impl Account { let curve_key = device.get_key(KeyAlgorithm::Curve25519).ok_or_else(|| { SessionCreationError::DeviceMissingCurveKey( device.user_id().to_owned(), - device.device_id().to_owned(), + device.device_id().into(), ) })?; @@ -465,7 +465,7 @@ impl Account { .map_err(|e| { SessionCreationError::OlmError( device.user_id().to_owned(), - device.device_id().to_owned(), + device.device_id().into(), e, ) }) @@ -821,8 +821,8 @@ impl InboundGroupSession { /// * `event` - The event that should be decrypted. pub async fn decrypt( &self, - event: &MessageEventStub, - ) -> MegolmResult<(EventJson, u32)> { + event: &SyncMessageEvent, + ) -> MegolmResult<(EventJson, u32)> { let content = match &event.content { EncryptedEventContent::MegolmV1AesSha2(c) => c, _ => return Err(EventError::UnsupportedAlgorithm.into()), @@ -853,7 +853,7 @@ impl InboundGroupSession { ); Ok(( - serde_json::from_value::>(decrypted_value)?, + serde_json::from_value::>(decrypted_value)?, message_index, )) } @@ -882,7 +882,7 @@ impl PartialEq for InboundGroupSession { #[derive(Clone)] pub struct OutboundGroupSession { inner: Arc>, - device_id: Arc, + device_id: Arc>, account_identity_keys: Arc, session_id: Arc, room_id: Arc, @@ -904,7 +904,11 @@ impl OutboundGroupSession { /// session. /// /// * `room_id` - The id of the room that the session is used in. - fn new(device_id: Arc, identity_keys: Arc, room_id: &RoomId) -> Self { + fn new( + device_id: Arc>, + identity_keys: Arc, + room_id: &RoomId, + ) -> Self { let session = OlmOutboundGroupSession::new(); let session_id = session.session_id(); @@ -966,12 +970,14 @@ impl OutboundGroupSession { let ciphertext = self.encrypt_helper(plaintext).await; - EncryptedEventContent::MegolmV1AesSha2(MegolmV1AesSha2Content { - ciphertext, - sender_key: self.account_identity_keys.curve25519().to_owned(), - session_id: self.session_id().to_owned(), - device_id: (&*self.device_id).to_owned(), - }) + EncryptedEventContent::MegolmV1AesSha2(MegolmV1AesSha2Content::new( + matrix_sdk_common::events::room::encrypted::MegolmV1AesSha2ContentInit { + ciphertext, + sender_key: self.account_identity_keys.curve25519().to_owned(), + session_id: self.session_id().to_owned(), + device_id: (&*self.device_id).to_owned(), + }, + )) } /// Check if the session has expired and if it should be rotated. @@ -1044,16 +1050,16 @@ pub(crate) mod test { UserId::try_from("@alice:example.org").unwrap() } - fn alice_device_id() -> DeviceId { - "ALICEDEVICE".to_string() + fn alice_device_id() -> Box { + "ALICEDEVICE".into() } fn bob_id() -> UserId { UserId::try_from("@bob:example.org").unwrap() } - fn bob_device_id() -> DeviceId { - "BOBDEVICE".to_string() + fn bob_device_id() -> Box { + "BOBDEVICE".into() } pub(crate) async fn get_account_and_session() -> (Account, Session) { diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index cfda0a62..44feb951 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -469,14 +469,14 @@ impl SqliteStore { let key = &row.1; keys.insert( - AlgorithmAndDeviceId(algorithm, device_id.clone()), + AlgorithmAndDeviceId(algorithm, device_id.as_str().into()), key.to_owned(), ); } let device = Device::new( user_id, - device_id.to_owned(), + device_id.as_str().into(), display_name.clone(), trust_state, algorithms, @@ -840,16 +840,16 @@ mod test { UserId::try_from("@alice:example.org").unwrap() } - fn alice_device_id() -> DeviceId { - "ALICEDEVICE".to_string() + fn alice_device_id() -> Box { + "ALICEDEVICE".into() } fn bob_id() -> UserId { UserId::try_from("@bob:example.org").unwrap() } - fn bob_device_id() -> DeviceId { - "BOBDEVICE".to_string() + fn bob_device_id() -> Box { + "BOBDEVICE".into() } fn get_account() -> Account {