diff --git a/matrix_sdk_crypto/src/device.rs b/matrix_sdk_crypto/src/device.rs index 6c3ed7ce..46dc0ac4 100644 --- a/matrix_sdk_crypto/src/device.rs +++ b/matrix_sdk_crypto/src/device.rs @@ -36,7 +36,7 @@ pub struct Device { user_id: Arc, device_id: Arc, algorithms: Arc>, - keys: Arc>, + keys: Arc>, display_name: Arc>, deleted: Arc, trust_state: Arc>, @@ -75,7 +75,7 @@ impl Device { display_name: Option, trust_state: TrustState, algorithms: Vec, - keys: BTreeMap, + keys: BTreeMap, ) -> Self { Device { user_id: Arc::new(user_id), @@ -105,11 +105,12 @@ impl Device { /// Get the key of the given key algorithm belonging to this device. pub fn get_key(&self, algorithm: KeyAlgorithm) -> Option<&String> { - self.keys.get(&algorithm) + self.keys + .get(&AlgorithmAndDeviceId(algorithm, self.device_id.to_string())) } /// Get a map containing all the device keys. - pub fn keys(&self) -> &BTreeMap { + pub fn keys(&self) -> &BTreeMap { &self.keys } @@ -132,13 +133,6 @@ impl Device { pub(crate) fn update_device(&mut self, device_keys: &DeviceKeys) -> Result<(), SignatureError> { self.verify_device_keys(device_keys)?; - let mut keys = BTreeMap::new(); - - for (key_id, key) in device_keys.keys.iter() { - let key_id = key_id.0; - let _ = keys.insert(key_id, key.clone()); - } - let display_name = Arc::new( device_keys .unsigned @@ -151,7 +145,7 @@ impl Device { &mut self.algorithms, Arc::new(device_keys.algorithms.clone()), ); - let _ = mem::replace(&mut self.keys, Arc::new(keys)); + let _ = mem::replace(&mut self.keys, Arc::new(device_keys.keys.clone())); let _ = mem::replace(&mut self.display_name, display_name); Ok(()) @@ -159,8 +153,7 @@ impl Device { fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> { let signing_key = self - .keys - .get(&KeyAlgorithm::Ed25519) + .get_key(KeyAlgorithm::Ed25519) .ok_or(SignatureError::MissingSigningKey)?; let json_object = json.as_object_mut().ok_or(SignatureError::NotAnObject)?; @@ -232,7 +225,10 @@ impl From<&OlmMachine> for Device { .iter() .map(|(key, value)| { ( - KeyAlgorithm::try_from(key.as_ref()).unwrap(), + AlgorithmAndDeviceId( + KeyAlgorithm::try_from(key.as_ref()).unwrap(), + machine.device_id().clone(), + ), value.to_owned(), ) }) @@ -249,18 +245,11 @@ impl TryFrom<&DeviceKeys> for Device { type Error = SignatureError; fn try_from(device_keys: &DeviceKeys) -> Result { - let mut keys = BTreeMap::new(); - - for (key_id, key) in device_keys.keys.iter() { - let key_id = key_id.0; - let _ = keys.insert(key_id, key.clone()); - } - let device = Device { user_id: Arc::new(device_keys.user_id.clone()), device_id: Arc::new(device_keys.device_id.clone()), algorithms: Arc::new(device_keys.algorithms.clone()), - keys: Arc::new(keys), + keys: Arc::new(device_keys.keys.clone()), display_name: Arc::new( device_keys .unsigned diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index dd9bbb5d..2287afe9 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -29,7 +29,7 @@ use zeroize::Zeroizing; use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; use crate::device::{Device, TrustState}; use crate::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; -use matrix_sdk_common::api::r0::keys::KeyAlgorithm; +use matrix_sdk_common::api::r0::keys::{AlgorithmAndDeviceId, KeyAlgorithm}; use matrix_sdk_common::events::Algorithm; use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId}; @@ -468,7 +468,10 @@ impl SqliteStore { let key = &row.1; - keys.insert(algorithm, key.to_owned()); + keys.insert( + AlgorithmAndDeviceId(algorithm, device_id.clone()), + key.to_owned(), + ); } let device = Device::new( @@ -541,7 +544,7 @@ impl SqliteStore { ", ) .bind(device_row_id) - .bind(key_algorithm.to_string()) + .bind(key_algorithm.0.to_string()) .bind(key) .execute(&mut *connection) .await?;