From b8d6a4c49acf187834c08099411481e3fef4826c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 10 Apr 2020 15:28:43 +0200 Subject: [PATCH 1/8] crypto: Move the account mutex into the account struct. --- src/crypto/machine.rs | 69 +++++++++------------ src/crypto/olm.rs | 104 ++++++++++++++++++++------------ src/crypto/store/memorystore.rs | 2 +- src/crypto/store/mod.rs | 2 +- src/crypto/store/sqlite.rs | 42 ++++++------- 5 files changed, 114 insertions(+), 105 deletions(-) diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 1037bc1d..2e95daf2 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -68,7 +68,7 @@ pub struct OlmMachine { /// The unique device id of the device that holds this account. device_id: DeviceId, /// Our underlying Olm Account holding our identity keys. - account: Arc>, + account: Account, /// The number of signed one-time keys we have uploaded to the server. If /// this is None, no action will be taken. After a sync request the client /// needs to set this for us, depending on the count we will suggest the @@ -98,7 +98,7 @@ impl OlmMachine { Ok(OlmMachine { user_id: user_id.clone(), device_id: device_id.to_owned(), - account: Arc::new(Mutex::new(Account::new())), + account: Account::new(), uploaded_signed_key_count: None, store: Box::new(MemoryStore::new()), users_for_key_query: HashSet::new(), @@ -132,7 +132,7 @@ impl OlmMachine { Ok(OlmMachine { user_id: user_id.clone(), device_id: device_id.to_owned(), - account: Arc::new(Mutex::new(account)), + account, uploaded_signed_key_count: None, store: Box::new(store), users_for_key_query: HashSet::new(), @@ -142,7 +142,7 @@ impl OlmMachine { /// Should account or one-time keys be uploaded to the server. pub async fn should_upload_keys(&self) -> bool { - if !self.account.lock().await.shared() { + if !self.account.shared() { return true; } @@ -150,7 +150,7 @@ impl OlmMachine { // max_one_time_Keys() / 2, otherwise tell the client to upload more. match self.uploaded_signed_key_count { Some(count) => { - let max_keys = self.account.lock().await.max_one_time_keys() as u64; + let max_keys = self.account.max_one_time_keys().await as u64; let key_count = (max_keys / 2) - count; key_count > 0 } @@ -169,11 +169,10 @@ impl OlmMachine { &mut self, response: &keys::upload_keys::Response, ) -> Result<()> { - let mut account = self.account.lock().await; - if !account.shared { + if !self.account.shared() { debug!("Marking account as shared"); } - account.shared = true; + self.account.mark_as_shared(); let one_time_key_count = response .one_time_key_counts @@ -187,10 +186,9 @@ impl OlmMachine { ); self.uploaded_signed_key_count = Some(count); - account.mark_keys_as_published(); - drop(account); + self.account.mark_keys_as_published().await; - self.store.save_account(self.account.clone()).await?; + // self.store.save_account(self.account.clone()).await?; Ok(()) } @@ -317,9 +315,8 @@ impl OlmMachine { let session = match self .account - .lock() - .await .create_outbound_session(curve_key, &one_time_key) + .await { Ok(s) => s, Err(e) => { @@ -441,10 +438,9 @@ impl OlmMachine { /// Returns the number of newly generated one-time keys. If no keys can be /// generated returns an empty error. async fn generate_one_time_keys(&self) -> StdResult { - let account = self.account.lock().await; match self.uploaded_signed_key_count { Some(count) => { - let max_keys = account.max_one_time_keys() as u64; + let max_keys = self.account.max_one_time_keys().await as u64; let max_on_server = max_keys / 2; if count >= (max_on_server) { @@ -453,11 +449,11 @@ impl OlmMachine { let key_count = (max_on_server) - count; - let key_count: usize = key_count - .try_into() - .unwrap_or_else(|_| account.max_one_time_keys()); + let max_keys = self.account.max_one_time_keys().await; - account.generate_one_time_keys(key_count); + let key_count: usize = key_count.try_into().unwrap_or(max_keys); + + self.account.generate_one_time_keys(key_count).await; Ok(key_count as u64) } None => Err(()), @@ -466,7 +462,7 @@ impl OlmMachine { /// Sign the device keys and return a JSON Value to upload them. async fn device_keys(&self) -> DeviceKeys { - let identity_keys = self.account.lock().await.identity_keys(); + let identity_keys = self.account.identity_keys(); let mut keys = HashMap::new(); @@ -513,7 +509,7 @@ impl OlmMachine { /// If no one-time keys need to be uploaded returns an empty error. async fn signed_one_time_keys(&self) -> StdResult { let _ = self.generate_one_time_keys().await?; - let one_time_keys = self.account.lock().await.one_time_keys(); + let one_time_keys = self.account.one_time_keys().await; let mut one_time_key_map = HashMap::new(); for (key_id, key) in one_time_keys.curve25519().iter() { @@ -555,10 +551,9 @@ impl OlmMachine { /// * `json` - The value that should be converted into a canonical JSON /// string. async fn sign_json(&self, json: &Value) -> String { - let account = self.account.lock().await; let canonical_json = cjson::to_string(json) .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json))); - account.sign(&canonical_json) + self.account.sign(&canonical_json).await } /// Verify a signed JSON object. @@ -637,7 +632,7 @@ impl OlmMachine { return Err(()); } - let shared = self.account.lock().await.shared(); + let shared = self.account.shared(); let device_keys = if !shared { Some(self.device_keys().await) @@ -702,8 +697,9 @@ impl OlmMachine { let mut session = match &message { OlmMessage::Message(_) => return Err(OlmError::SessionWedged), OlmMessage::PreKey(m) => { - let account = self.account.lock().await; - account.create_inbound_session(sender_key, m.clone())? + self.account + .create_inbound_session(sender_key, m.clone()) + .await? } }; @@ -740,7 +736,7 @@ impl OlmMachine { return Err(OlmError::UnsupportedAlgorithm); }; - let identity_keys = self.account.lock().await.identity_keys(); + let identity_keys = self.account.identity_keys(); let own_key = identity_keys.curve25519(); let own_ciphertext = content.ciphertext.get(own_key); @@ -799,8 +795,7 @@ impl OlmMachine { async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> Result<()> { let session = OutboundGroupSession::new(room_id); - let account = self.account.lock().await; - let identity_keys = account.identity_keys(); + let identity_keys = self.account.identity_keys(); let sender_key = identity_keys.curve25519(); let signing_key = identity_keys.ed25519(); @@ -855,13 +850,7 @@ impl OlmMachine { Ok(MegolmV1AesSha2Content { algorithm: Algorithm::MegolmV1AesSha2, ciphertext, - sender_key: self - .account - .lock() - .await - .identity_keys() - .curve25519() - .to_owned(), + sender_key: self.account.identity_keys().curve25519().to_owned(), session_id: session.session_id().to_owned(), device_id: self.device_id.to_owned(), }) @@ -874,7 +863,7 @@ impl OlmMachine { event_type: EventType, content: Value, ) -> Result { - let identity_keys = self.account.lock().await.identity_keys(); + let identity_keys = self.account.identity_keys(); let recipient_signing_key = recipient_device .keys(&KeyAlgorithm::Ed25519) @@ -1326,7 +1315,7 @@ mod test { let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let mut device_keys = machine.device_keys().await; - let identity_keys = machine.account.lock().await.identity_keys(); + let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); let ret = machine.verify_json( @@ -1359,7 +1348,7 @@ mod test { machine.uploaded_signed_key_count = Some(49); let mut one_time_keys = machine.signed_one_time_keys().await.unwrap(); - let identity_keys = machine.account.lock().await.identity_keys(); + let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap(); @@ -1378,7 +1367,7 @@ mod test { let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); machine.uploaded_signed_key_count = Some(0); - let identity_keys = machine.account.lock().await.identity_keys(); + let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); let (device_keys, mut one_time_keys) = machine diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 8520b4f1..f9fc99bc 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -33,9 +33,11 @@ use crate::identifiers::RoomId; /// The Olm account. /// An account is the central identity for encrypted communication between two /// devices. It holds the two identity key pairs for a device. +#[derive(Clone)] pub struct Account { - inner: OlmAccount, - pub(crate) shared: bool, + inner: Arc>, + identity_keys: Arc, + pub(crate) shared: Arc, } impl fmt::Debug for Account { @@ -44,7 +46,7 @@ impl fmt::Debug for Account { f, "Olm Account: {:?}, shared: {}", self.identity_keys(), - self.shared + self.shared() ) } } @@ -52,49 +54,61 @@ impl fmt::Debug for Account { impl Account { /// Create a new account. pub fn new() -> Self { + let account = OlmAccount::new(); + let identity_keys = account.parsed_identity_keys(); + Account { - inner: OlmAccount::new(), - shared: false, + inner: Arc::new(Mutex::new(account)), + identity_keys: Arc::new(identity_keys), + shared: Arc::new(AtomicBool::new(false)), } } /// Get the public parts of the identity keys for the account. - pub fn identity_keys(&self) -> IdentityKeys { - self.inner.parsed_identity_keys() + pub fn identity_keys(&self) -> &IdentityKeys { + &self.identity_keys } /// Has the account been shared with the server. pub fn shared(&self) -> bool { - self.shared + self.shared.load(Ordering::Relaxed) + } + + /// Mark the account as shared. + /// + /// Messages shouldn't be encrypted with the session before it has been + /// shared. + pub fn mark_as_shared(&self) { + self.shared.store(true, Ordering::Relaxed); } /// Get the one-time keys of the account. /// /// This can be empty, keys need to be generated first. - pub fn one_time_keys(&self) -> OneTimeKeys { - self.inner.parsed_one_time_keys() + pub async fn one_time_keys(&self) -> OneTimeKeys { + self.inner.lock().await.parsed_one_time_keys() } /// Generate count number of one-time keys. - pub fn generate_one_time_keys(&self, count: usize) { - self.inner.generate_one_time_keys(count); + pub async fn generate_one_time_keys(&self, count: usize) { + self.inner.lock().await.generate_one_time_keys(count); } /// Get the maximum number of one-time keys the account can hold. - pub fn max_one_time_keys(&self) -> usize { - self.inner.max_number_of_one_time_keys() + pub async fn max_one_time_keys(&self) -> usize { + self.inner.lock().await.max_number_of_one_time_keys() } /// Mark the current set of one-time keys as being published. - pub fn mark_keys_as_published(&self) { - self.inner.mark_keys_as_published(); + pub async fn mark_keys_as_published(&self) { + self.inner.lock().await.mark_keys_as_published(); } /// Sign the given string using the accounts signing key. /// /// Returns the signature as a base64 encoded string. - pub fn sign(&self, string: &str) -> String { - self.inner.sign(string) + pub async fn sign(&self, string: &str) -> String { + self.inner.lock().await.sign(string) } /// Store the account as a base64 encoded string. @@ -103,8 +117,8 @@ impl Account { /// /// * `pickle_mode` - The mode that was used to pickle the account, either an /// unencrypted mode or an encrypted using passphrase. - pub fn pickle(&self, pickle_mode: PicklingMode) -> String { - self.inner.pickle(pickle_mode) + pub async fn pickle(&self, pickle_mode: PicklingMode) -> String { + self.inner.lock().await.pickle(pickle_mode) } /// Restore an account from a previously pickled string. @@ -123,8 +137,14 @@ impl Account { pickle_mode: PicklingMode, shared: bool, ) -> Result { - let acc = OlmAccount::unpickle(pickle, pickle_mode)?; - Ok(Account { inner: acc, shared }) + let account = OlmAccount::unpickle(pickle, pickle_mode)?; + let identity_keys = account.parsed_identity_keys(); + + Ok(Account { + inner: Arc::new(Mutex::new(account)), + identity_keys: Arc::new(identity_keys), + shared: Arc::new(AtomicBool::from(shared)), + }) } /// Create a new session with another account given a one-time key. @@ -137,13 +157,15 @@ impl Account { /// /// * `their_one_time_key` - A signed one-time key that the other account /// created and shared with us. - pub fn create_outbound_session( + pub async fn create_outbound_session( &self, their_identity_key: &str, their_one_time_key: &SignedKey, ) -> Result { let session = self .inner + .lock() + .await .create_outbound_session(their_identity_key, &their_one_time_key.key)?; let now = Instant::now(); @@ -166,13 +188,15 @@ impl Account { /// /// * `message` - A pre-key Olm message that was sent to us by the other /// account. - pub fn create_inbound_session( + pub async fn create_inbound_session( &self, their_identity_key: &str, message: PreKeyMessage, ) -> Result { let session = self .inner + .lock() + .await .create_inbound_session_from(their_identity_key, message)?; let now = Instant::now(); @@ -188,7 +212,7 @@ impl Account { impl PartialEq for Account { fn eq(&self, other: &Self) -> bool { - self.identity_keys() == other.identity_keys() && self.shared == other.shared + self.identity_keys() == other.identity_keys() && self.shared() == other.shared() } } @@ -566,16 +590,16 @@ mod test { assert!(!identyty_keys.curve25519().is_empty()); } - #[test] - fn one_time_keys_creation() { + #[tokio::test] + async fn one_time_keys_creation() { let account = Account::new(); - let one_time_keys = account.one_time_keys(); + let one_time_keys = account.one_time_keys().await; assert!(one_time_keys.curve25519().is_empty()); - assert_ne!(account.max_one_time_keys(), 0); + assert_ne!(account.max_one_time_keys().await, 0); - account.generate_one_time_keys(10); - let one_time_keys = account.one_time_keys(); + account.generate_one_time_keys(10).await; + let one_time_keys = account.one_time_keys().await; assert!(!one_time_keys.curve25519().is_empty()); assert_ne!(one_time_keys.values().len(), 0); @@ -588,21 +612,19 @@ mod test { one_time_keys.get("curve25519").unwrap() ); - account.mark_keys_as_published(); - let one_time_keys = account.one_time_keys(); + account.mark_keys_as_published().await; + let one_time_keys = account.one_time_keys().await; assert!(one_time_keys.curve25519().is_empty()); } - #[test] - fn session_creation() { + #[tokio::test] + async fn session_creation() { let alice = Account::new(); let bob = Account::new(); let alice_keys = alice.identity_keys(); - let one_time_keys = alice.one_time_keys(); - - alice.generate_one_time_keys(1); - let one_time_keys = alice.one_time_keys(); - alice.mark_keys_as_published(); + alice.generate_one_time_keys(1).await; + let one_time_keys = alice.one_time_keys().await; + alice.mark_keys_as_published().await; let one_time_key = one_time_keys .curve25519() @@ -619,6 +641,7 @@ mod test { let mut bob_session = bob .create_outbound_session(alice_keys.curve25519(), &one_time_key) + .await .unwrap(); let plaintext = "Hello world"; @@ -633,6 +656,7 @@ mod test { let bob_keys = bob.identity_keys(); let mut alice_session = alice .create_inbound_session(bob_keys.curve25519(), prekey_message) + .await .unwrap(); assert_eq!(bob_session.session_id(), alice_session.session_id()); diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index 2ab3bc97..23116b67 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -48,7 +48,7 @@ impl CryptoStore for MemoryStore { Ok(None) } - async fn save_account(&mut self, _: Arc>) -> Result<()> { + async fn save_account(&mut self, _: Account) -> Result<()> { Ok(()) } diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index aa485f82..1b41f63b 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -66,7 +66,7 @@ pub type Result = std::result::Result; #[async_trait] pub trait CryptoStore: Debug + Send + Sync { async fn load_account(&mut self) -> Result>; - async fn save_account(&mut self, account: Arc>) -> Result<()>; + async fn save_account(&mut self, account: Account) -> Result<()>; async fn save_session(&mut self, session: Arc>) -> Result<()>; async fn add_and_save_session(&mut self, session: Session) -> Result<()>; async fn get_sessions( diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 32742326..a629214c 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -288,9 +288,8 @@ impl CryptoStore for SqliteStore { Ok(result) } - async fn save_account(&mut self, account: Arc>) -> Result<()> { - let acc = account.lock().await; - let pickle = acc.pickle(self.get_pickle_mode()); + async fn save_account(&mut self, account: Account) -> Result<()> { + let pickle = account.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; query( @@ -307,7 +306,7 @@ impl CryptoStore for SqliteStore { .bind(&*self.user_id.to_string()) .bind(&*self.device_id.to_string()) .bind(&pickle) - .bind(acc.shared) + .bind(account.shared()) .execute(&mut *connection) .await?; @@ -460,7 +459,7 @@ mod test { .expect("Can't create store") } - async fn get_loaded_store() -> (Arc>, SqliteStore) { + async fn get_loaded_store() -> (Account, SqliteStore) { let mut store = get_store().await; let account = get_account(); store @@ -471,19 +470,19 @@ mod test { (account, store) } - fn get_account() -> Arc> { - let account = Account::new(); - Arc::new(Mutex::new(account)) + fn get_account() -> Account { + Account::new() } - fn get_account_and_session() -> (Arc>, Session) { + async fn get_account_and_session() -> (Account, Session) { let alice = Account::new(); let bob = Account::new(); - bob.generate_one_time_keys(1); + bob.generate_one_time_keys(1).await; let one_time_key = bob .one_time_keys() + .await .curve25519() .iter() .nth(0) @@ -497,9 +496,10 @@ mod test { let sender_key = bob.identity_keys().curve25519().to_owned(); let session = alice .create_outbound_session(&sender_key, &one_time_key) + .await .unwrap(); - (Arc::new(Mutex::new(alice)), session) + (alice, session) } #[tokio::test] @@ -532,11 +532,10 @@ mod test { .await .expect("Can't save account"); - let acc = account.lock().await; let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); - assert_eq!(*acc, loaded_account); + assert_eq!(account, loaded_account); } #[tokio::test] @@ -549,7 +548,7 @@ mod test { .await .expect("Can't save account"); - account.lock().await.shared = true; + account.mark_as_shared(); store .save_account(account.clone()) @@ -558,15 +557,14 @@ mod test { let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); - let acc = account.lock().await; - assert_eq!(*acc, loaded_account); + assert_eq!(account, loaded_account); } #[tokio::test] async fn save_session() { let mut store = get_store().await; - let (account, session) = get_account_and_session(); + let (account, session) = get_account_and_session().await; let session = Arc::new(Mutex::new(session)); assert!(store.save_session(session.clone()).await.is_err()); @@ -582,7 +580,7 @@ mod test { #[tokio::test] async fn load_sessions() { let mut store = get_store().await; - let (account, session) = get_account_and_session(); + let (account, session) = get_account_and_session().await; let session = Arc::new(Mutex::new(session)); store .save_account(account.clone()) @@ -604,7 +602,7 @@ mod test { #[tokio::test] async fn add_and_save_session() { let mut store = get_store().await; - let (account, session) = get_account_and_session(); + let (account, session) = get_account_and_session().await; let sender_key = session.sender_key.to_owned(); let session_id = session.session_id(); @@ -625,8 +623,7 @@ mod test { async fn save_inbound_group_session() { let (account, mut store) = get_loaded_store().await; - let acc = account.lock().await; - let identity_keys = acc.identity_keys(); + let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); let session = InboundGroupSession::new( identity_keys.curve25519(), @@ -646,8 +643,7 @@ mod test { async fn load_inbound_group_session() { let (account, mut store) = get_loaded_store().await; - let acc = account.lock().await; - let identity_keys = acc.identity_keys(); + let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); let session = InboundGroupSession::new( identity_keys.curve25519(), From a4d41378d4e0935310f452971712c45fd9f1966e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 10 Apr 2020 16:08:47 +0200 Subject: [PATCH 2/8] crypto: Move the inbound group session lock into the session struct. --- src/crypto/machine.rs | 2 +- src/crypto/memory_stores.rs | 17 +++++----- src/crypto/olm.rs | 55 +++++++++++++++++++-------------- src/crypto/store/memorystore.rs | 2 +- src/crypto/store/mod.rs | 6 +++- src/crypto/store/sqlite.rs | 12 +++---- 6 files changed, 54 insertions(+), 40 deletions(-) diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 2e95daf2..1ebc5526 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -1139,7 +1139,7 @@ impl OlmMachine { // TODO check if the olm session is wedged and re-request the key. let session = session.ok_or(OlmError::MissingSession)?; - let (plaintext, _) = session.lock().await.decrypt(content.ciphertext.clone())?; + let (plaintext, _) = session.decrypt(content.ciphertext.clone()).await?; // TODO check the message index. // TODO check if this is from a verified device. diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index de2e67e3..89704ad1 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -60,7 +60,7 @@ impl SessionStore { #[derive(Debug)] pub struct GroupSessionStore { - entries: HashMap>>>>, + entries: HashMap>>, } impl GroupSessionStore { @@ -72,18 +72,19 @@ impl GroupSessionStore { pub fn add(&mut self, session: InboundGroupSession) -> bool { if !self.entries.contains_key(&session.room_id) { - self.entries - .insert(session.room_id.to_owned(), HashMap::new()); + let room_id = &*session.room_id; + self.entries.insert(room_id.clone(), HashMap::new()); } let room_map = self.entries.get_mut(&session.room_id).unwrap(); - if !room_map.contains_key(&session.sender_key) { - room_map.insert(session.sender_key.to_owned(), HashMap::new()); + if !room_map.contains_key(&*session.sender_key) { + let sender_key = &*session.sender_key; + room_map.insert(sender_key.to_owned(), HashMap::new()); } - let sender_map = room_map.get_mut(&session.sender_key).unwrap(); - let ret = sender_map.insert(session.session_id(), Arc::new(Mutex::new(session))); + let sender_map = room_map.get_mut(&*session.sender_key).unwrap(); + let ret = sender_map.insert(session.session_id().to_owned(), session); ret.is_some() } @@ -93,7 +94,7 @@ impl GroupSessionStore { room_id: &RoomId, sender_key: &str, session_id: &str, - ) -> Option>> { + ) -> Option { self.entries .get(room_id) .and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id).cloned())) diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index f9fc99bc..a615fa0b 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -343,12 +343,14 @@ pub struct GroupSessionKey(pub String); /// /// Inbound group sessions are used to exchange room messages between a group of /// participants. Inbound group sessions are used to decrypt the room messages. +#[derive(Clone)] pub struct InboundGroupSession { - inner: OlmInboundGroupSession, - pub(crate) sender_key: String, - pub(crate) signing_key: String, - pub(crate) room_id: RoomId, - forwarding_chains: Option>, + inner: Arc>, + session_id: Arc, + pub(crate) sender_key: Arc, + pub(crate) signing_key: Arc, + pub(crate) room_id: Arc, + forwarding_chains: Arc>>>, } impl InboundGroupSession { @@ -374,12 +376,16 @@ impl InboundGroupSession { room_id: &RoomId, session_key: GroupSessionKey, ) -> Result { + let session = OlmInboundGroupSession::new(&session_key.0)?; + let session_id = session.session_id(); + Ok(InboundGroupSession { - inner: OlmInboundGroupSession::new(&session_key.0)?, - sender_key: sender_key.to_owned(), - signing_key: signing_key.to_owned(), - room_id: room_id.clone(), - forwarding_chains: None, + inner: Arc::new(Mutex::new(session)), + session_id: Arc::new(session_id), + sender_key: Arc::new(sender_key.to_owned()), + signing_key: Arc::new(signing_key.to_owned()), + room_id: Arc::new(room_id.clone()), + forwarding_chains: Arc::new(Mutex::new(None)), }) } @@ -389,8 +395,8 @@ impl InboundGroupSession { /// /// * `pickle_mode` - The mode that was used to pickle the group session, /// either an unencrypted mode or an encrypted using passphrase. - pub fn pickle(&self, pickle_mode: PicklingMode) -> String { - self.inner.pickle(pickle_mode) + pub async fn pickle(&self, pickle_mode: PicklingMode) -> String { + self.inner.lock().await.pickle(pickle_mode) } /// Restore a Session from a previously pickled string. @@ -420,23 +426,26 @@ impl InboundGroupSession { room_id: RoomId, ) -> Result { let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?; + let session_id = session.session_id(); + Ok(InboundGroupSession { - inner: session, - sender_key, - signing_key, - room_id, - forwarding_chains: None, + inner: Arc::new(Mutex::new(session)), + session_id: Arc::new(session_id), + sender_key: Arc::new(sender_key), + signing_key: Arc::new(signing_key), + room_id: Arc::new(room_id), + forwarding_chains: Arc::new(Mutex::new(None)), }) } /// Returns the unique identifier for this session. - pub fn session_id(&self) -> String { - self.inner.session_id() + pub fn session_id(&self) -> &str { + &self.session_id } /// Get the first message index we know how to decrypt. - pub fn first_known_index(&self) -> u32 { - self.inner.first_known_index() + pub async fn first_known_index(&self) -> u32 { + self.inner.lock().await.first_known_index() } /// Decrypt the given ciphertext. @@ -447,8 +456,8 @@ impl InboundGroupSession { /// # Arguments /// /// * `message` - The message that should be decrypted. - pub fn decrypt(&self, message: String) -> Result<(String, u32), OlmGroupSessionError> { - self.inner.decrypt(message) + pub async fn decrypt(&self, message: String) -> Result<(String, u32), OlmGroupSessionError> { + self.inner.lock().await.decrypt(message) } } diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index 23116b67..d14e29b7 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -77,7 +77,7 @@ impl CryptoStore for MemoryStore { room_id: &RoomId, sender_key: &str, session_id: &str, - ) -> Result>>> { + ) -> Result> { Ok(self .inbound_group_sessions .get(room_id, sender_key, session_id)) diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index 1b41f63b..016cb334 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -67,21 +67,25 @@ pub type Result = std::result::Result; pub trait CryptoStore: Debug + Send + Sync { async fn load_account(&mut self) -> Result>; async fn save_account(&mut self, account: Account) -> Result<()>; + async fn save_session(&mut self, session: Arc>) -> Result<()>; async fn add_and_save_session(&mut self, session: Session) -> Result<()>; async fn get_sessions( &mut self, sender_key: &str, ) -> Result>>>>>>; + async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result; async fn get_inbound_group_session( &mut self, room_id: &RoomId, sender_key: &str, session_id: &str, - ) -> Result>>>; + ) -> Result>; + fn tracked_users(&self) -> &HashSet; async fn add_user_for_tracking(&mut self, user: &UserId) -> Result; + async fn save_device(&self, device: Device) -> Result<()>; async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result>; async fn get_user_devices(&self, user_id: &UserId) -> Result; diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index a629214c..1b57a832 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -366,7 +366,7 @@ impl CryptoStore for SqliteStore { async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result { let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; - let pickle = session.pickle(self.get_pickle_mode()); + let pickle = session.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; let session_id = session.session_id(); @@ -382,9 +382,9 @@ impl CryptoStore for SqliteStore { ) .bind(session_id) .bind(account_id) - .bind(&session.sender_key) - .bind(&session.signing_key) - .bind(&session.room_id.to_string()) + .bind(&*session.sender_key) + .bind(&*session.signing_key) + .bind(&*session.room_id.to_string()) .bind(&pickle) .execute(&mut *connection) .await?; @@ -397,7 +397,7 @@ impl CryptoStore for SqliteStore { room_id: &RoomId, sender_key: &str, session_id: &str, - ) -> Result>>> { + ) -> Result> { Ok(self .inbound_group_sessions .get(room_id, sender_key, session_id)) @@ -653,7 +653,7 @@ mod test { ) .expect("Can't create session"); - let session_id = session.session_id(); + let session_id = session.session_id().to_owned(); store .save_inbound_group_session(session) From c282d9fabc1ef102447efaf5da478b1a0a727cd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 10 Apr 2020 16:17:31 +0200 Subject: [PATCH 3/8] machine: Uncomment account saving after keys were published. --- src/crypto/machine.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 1ebc5526..0a717857 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -187,8 +187,7 @@ impl OlmMachine { self.uploaded_signed_key_count = Some(count); self.account.mark_keys_as_published().await; - - // self.store.save_account(self.account.clone()).await?; + self.store.save_account(self.account.clone()).await?; Ok(()) } From 7577ddfc003d56b663115f365e0d862fcd9cf720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 10 Apr 2020 16:18:29 +0200 Subject: [PATCH 4/8] crypto: Remove one-time keys after a inbound session was created successfully. --- src/crypto/olm.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index a615fa0b..e4cb1f1e 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -199,6 +199,14 @@ impl Account { .await .create_inbound_session_from(their_identity_key, message)?; + self.inner + .lock() + .await + .remove_one_time_keys(&session) + .expect( + "Session was successfully created but the account doesn't hold a matching one-time key", + ); + let now = Instant::now(); Ok(Session { From 01656690bcde75f17b9905b5cc92f7e93292605a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 10 Apr 2020 16:18:55 +0200 Subject: [PATCH 5/8] crypto: Save the account after an inbound session was created. --- src/crypto/machine.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 0a717857..28dd352e 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -696,9 +696,12 @@ impl OlmMachine { let mut session = match &message { OlmMessage::Message(_) => return Err(OlmError::SessionWedged), OlmMessage::PreKey(m) => { - self.account + let session = self + .account .create_inbound_session(sender_key, m.clone()) - .await? + .await?; + self.store.save_account(self.account.clone()).await?; + session } }; From 8210c2377dbd5bc15003eff3983ddedd6cdfe08e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 10 Apr 2020 17:02:30 +0200 Subject: [PATCH 6/8] crypto: Take the session key out of the RoomKey event. --- src/crypto/machine.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 28dd352e..d3fe2972 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -14,6 +14,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::TryInto; +use std::mem; #[cfg(feature = "sqlite-cryptostore")] use std::path::Path; use std::result::Result as StdResult; @@ -751,11 +752,11 @@ impl OlmMachine { OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone()) .map_err(|_| OlmError::UnsupportedOlmType)?; - let decrypted_event = self + let mut decrypted_event = self .decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message) .await?; debug!("Decrypted a to-device event {:?}", decrypted_event); - self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event) + self.handle_decrypted_to_device_event(&content.sender_key, &mut decrypted_event) .await?; Ok(decrypted_event) @@ -765,7 +766,7 @@ impl OlmMachine { } } - async fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> { + async fn add_room_key(&mut self, sender_key: &str, event: &mut ToDeviceRoomKey) -> Result<()> { match event.content.algorithm { Algorithm::MegolmV1AesSha2 => { // TODO check for all the valid fields. @@ -774,7 +775,7 @@ impl OlmMachine { .get("ed25519") .ok_or(OlmError::MissingSigningKey)?; - let session_key = GroupSessionKey(event.content.session_key.to_owned()); + let session_key = GroupSessionKey(mem::take(&mut event.content.session_key)); let session = InboundGroupSession::new( sender_key, @@ -1038,7 +1039,7 @@ impl OlmMachine { async fn handle_decrypted_to_device_event( &mut self, sender_key: &str, - event: &EventResult, + event: &mut EventResult, ) -> Result<()> { let event = if let EventResult::Ok(e) = event { e From cb8f1c1a5bdc66f2cd951a8d66cf3564782669f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 10 Apr 2020 17:02:51 +0200 Subject: [PATCH 7/8] crypto: Zeroize the GroupSessionKey struct. --- Cargo.toml | 4 ++-- src/crypto/olm.rs | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8db93f37..0d76e75a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ version = "0.1.0" [features] default = [] -encryption = ["olm-rs", "serde/derive", "serde_json", "cjson"] +encryption = ["olm-rs", "serde/derive", "serde_json", "cjson", "zeroize"] sqlite-cryptostore = ["sqlx", "zeroize"] [dependencies] @@ -35,7 +35,7 @@ olm-rs = { git = "https://gitlab.gnome.org/poljar/olm-rs", optional = true, feat serde = { version = "1.0.106", optional = true, features = ["derive"] } serde_json = { version = "1.0.51", optional = true } cjson = { version = "0.1.0", optional = true } -zeroize = { version = "1.1.0", optional = true } +zeroize = { version = "1.1.0", optional = true, features = ["zeroize_derive"] } # Misc dependencies thiserror = "1.0.14" diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index e4cb1f1e..bf5e69ea 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -19,6 +19,7 @@ use std::time::Instant; use serde::Serialize; use tokio::sync::Mutex; +use zeroize::Zeroize; use olm_rs::account::{IdentityKeys, OlmAccount, OneTimeKeys}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; @@ -344,7 +345,8 @@ impl PartialEq for Session { /// The private session key of a group session. /// Can be used to create a new inbound group session. -#[derive(Clone, Serialize)] +#[derive(Clone, Serialize, Zeroize)] +#[zeroize(drop)] pub struct GroupSessionKey(pub String); /// Inbound group session. From 5f6cbbb193a8395d8062c7527fdc5b2de9e955a2 Mon Sep 17 00:00:00 2001 From: Caleb Bassi Date: Sat, 11 Apr 2020 12:01:29 -0700 Subject: [PATCH 8/8] fix matrix badge link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9737dfb3..51986c22 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ [![Build Status](https://img.shields.io/travis/matrix-org/matrix-rust-sdk.svg?style=flat-square)](https://travis-ci.org/matrix-org/matrix-rust-sdk) [![codecov](https://img.shields.io/codecov/c/github/matrix-org/matrix-rust-sdk/master.svg?style=flat-square)](https://codecov.io/gh/matrix-org/matrix-rust-sdk) [![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg?style=flat-square)](https://opensource.org/licenses/Apache-2.0) -[![#matrix-rust-sdk](https://img.shields.io/badge/matrix-%23matrix--rust--sdk-blue?style=flat-square)](https://matrix.to/#/!iYnZafYUoXkeVPOSQh:matrix.org?via=matrix.org&via=matrix.ffslfl.net&via=raim.ist) +[![#matrix-rust-sdk](https://img.shields.io/badge/matrix-%23matrix--rust--sdk-blue?style=flat-square)](https://matrix.to/#/#matrix-rust-sdk:matrix.org) # matrix-rust-sdk