diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index 287e8b96..04a73ea7 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -46,6 +46,55 @@ use crate::{ /// panic once we try to pickle a Signing object. const DEFAULT_PICKLE: &str = "DEFAULT_PICKLE_PASSPHRASE_123456"; +trait EncodeKey { + const SEPARATOR: u8 = 0xff; + fn encode(&self) -> Vec; +} + +impl EncodeKey for &UserId { + fn encode(&self) -> Vec { + self.as_str().encode() + } +} + +impl EncodeKey for &RoomId { + fn encode(&self) -> Vec { + self.as_str().encode() + } +} + +impl EncodeKey for &str { + fn encode(&self) -> Vec { + [self.as_bytes(), &[Self::SEPARATOR]].concat() + } +} + +impl EncodeKey for (&str, &str) { + fn encode(&self) -> Vec { + [ + self.0.as_bytes(), + &[Self::SEPARATOR], + self.1.as_bytes(), + &[Self::SEPARATOR], + ] + .concat() + } +} + +impl EncodeKey for (&str, &str, &str) { + fn encode(&self) -> Vec { + [ + self.0.as_bytes(), + &[Self::SEPARATOR], + self.1.as_bytes(), + &[Self::SEPARATOR], + self.2.as_bytes(), + &[Self::SEPARATOR], + ] + .concat() + } +} + /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Debug, Clone)] pub struct SledStore { @@ -143,7 +192,7 @@ impl SledStore { fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result { let key = if let Some(key) = database - .get("pickle_key")? + .get("pickle_key".encode())? .map(|v| serde_json::from_slice(&v)) { PickleKey::from_encrypted(passphrase, key?) @@ -151,7 +200,7 @@ impl SledStore { } else { let key = PickleKey::new(); let encrypted = key.encrypt(passphrase); - database.insert("pickle_key", serde_json::to_vec(&encrypted)?)?; + database.insert("pickle_key".encode(), serde_json::to_vec(&encrypted)?)?; key }; @@ -195,7 +244,7 @@ impl SledStore { let identity_keys = account.identity_keys; self.outbound_group_sessions - .get(room_id.as_str())? + .get(room_id.encode())? .map(|p| serde_json::from_slice(&p).map_err(CryptoStoreError::Serialization)) .transpose()? .map(|p| { @@ -231,7 +280,7 @@ impl SledStore { let session_id = session.session_id(); let pickle = session.pickle(self.get_pickle_mode()).await; - let key = format!("{}{}", sender_key, session_id); + let key = (sender_key, session_id).encode(); self.session_cache.add(session).await; session_changes.insert(key, pickle); @@ -243,7 +292,7 @@ impl SledStore { let room_id = session.room_id(); let sender_key = session.sender_key(); let session_id = session.session_id(); - let key = format!("{}{}{}", room_id, sender_key, session_id); + let key = (room_id.as_str(), sender_key, session_id).encode(); let pickle = session.pickle(self.get_pickle_mode()).await; inbound_session_changes.insert(key, pickle); @@ -284,33 +333,33 @@ impl SledStore { )| { if let Some(a) = &account_pickle { account.insert( - "account", + "account".encode(), serde_json::to_vec(a).map_err(ConflictableTransactionError::Abort)?, )?; } if let Some(i) = &private_identity_pickle { private_identity.insert( - "identity", + "identity".encode(), serde_json::to_vec(&i).map_err(ConflictableTransactionError::Abort)?, )?; } for device in device_changes.new.iter().chain(&device_changes.changed) { - let key = format!("{}{}", device.user_id(), device.device_id()); + let key = (device.user_id().as_str(), device.device_id().as_str()).encode(); let device = serde_json::to_vec(&device) .map_err(ConflictableTransactionError::Abort)?; - devices.insert(key.as_str(), device)?; + devices.insert(key, device)?; } for device in &device_changes.deleted { - let key = format!("{}{}", device.user_id(), device.device_id()); - devices.remove(key.as_str())?; + let key = (device.user_id().as_str(), device.device_id().as_str()).encode(); + devices.remove(key)?; } for identity in identity_changes.changed.iter().chain(&identity_changes.new) { identities.insert( - identity.user_id().as_str(), + identity.user_id().encode(), serde_json::to_vec(&identity) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -318,7 +367,7 @@ impl SledStore { for (key, session) in &session_changes { sessions.insert( - key.as_str(), + key.as_slice(), serde_json::to_vec(&session) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -326,7 +375,7 @@ impl SledStore { for (key, session) in &inbound_session_changes { inbound_sessions.insert( - key.as_str(), + key.as_slice(), serde_json::to_vec(&session) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -334,7 +383,7 @@ impl SledStore { for (key, session) in &outbound_session_changes { outbound_sessions.insert( - key.as_str(), + key.encode(), serde_json::to_vec(&session) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -362,7 +411,7 @@ impl SledStore { #[async_trait] impl CryptoStore for SledStore { async fn load_account(&self) -> Result> { - if let Some(pickle) = self.account.get("account")? { + if let Some(pickle) = self.account.get("account".encode())? { let pickle = serde_json::from_slice(&pickle)?; self.load_tracked_users().await?; @@ -379,7 +428,7 @@ impl CryptoStore for SledStore { async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { let pickle = account.pickle(self.get_pickle_mode()).await; self.account - .insert("account", serde_json::to_vec(&pickle)?)?; + .insert("account".encode(), serde_json::to_vec(&pickle)?)?; Ok(()) } @@ -397,7 +446,7 @@ impl CryptoStore for SledStore { if self.session_cache.get(sender_key).is_none() { let sessions: Result> = self .sessions - .scan_prefix(sender_key) + .scan_prefix(sender_key.encode()) .map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization)) .map(|p| { Session::from_pickle( @@ -423,7 +472,7 @@ impl CryptoStore for SledStore { sender_key: &str, session_id: &str, ) -> Result> { - let key = format!("{}{}{}", room_id, sender_key, session_id); + let key = (room_id.as_str(), sender_key, session_id).encode(); let pickle = self .inbound_group_sessions .get(&key)? @@ -487,7 +536,7 @@ impl CryptoStore for SledStore { user_id: &UserId, device_id: &DeviceId, ) -> Result> { - let key = format!("{}{}", user_id, device_id); + let key = (user_id.as_str(), device_id.as_str()).encode(); if let Some(d) = self.devices.get(key)? { Ok(Some(serde_json::from_slice(&d)?)) @@ -501,7 +550,7 @@ impl CryptoStore for SledStore { user_id: &UserId, ) -> Result> { self.devices - .scan_prefix(user_id.as_str()) + .scan_prefix(user_id.encode()) .map(|d| serde_json::from_slice(&d?.1).map_err(CryptoStoreError::Serialization)) .map(|d| { let d: ReadOnlyDevice = d?; @@ -513,13 +562,13 @@ impl CryptoStore for SledStore { async fn get_user_identity(&self, user_id: &UserId) -> Result> { Ok(self .identities - .get(user_id.as_str())? + .get(user_id.encode())? .map(|i| serde_json::from_slice(&i)) .transpose()?) } async fn save_value(&self, key: String, value: String) -> Result<()> { - self.values.insert(key.as_str(), value.as_str())?; + self.values.insert(key.as_str().encode(), value.as_str())?; Ok(()) } @@ -531,12 +580,12 @@ impl CryptoStore for SledStore { async fn get_value(&self, key: &str) -> Result> { Ok(self .values - .get(key)? + .get(key.encode())? .map(|v| String::from_utf8_lossy(&v).to_string())) } async fn load_identity(&self) -> Result> { - if let Some(i) = self.private_identity.get("identity")? { + if let Some(i) = self.private_identity.get("identity".encode())? { let pickle = serde_json::from_slice(&i)?; Ok(Some( PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key())