diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 07d8937d..6f50c1e8 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1344,6 +1344,8 @@ impl OlmMachine { /// # Arguments /// /// * `event` - The event that should be decrypted. + /// + /// * `room_id` - The ID of the room where the event was sent to. pub async fn decrypt_room_event( &mut self, event: &MessageEventStub, @@ -1361,35 +1363,10 @@ impl OlmMachine { // TODO check if the Olm session is wedged and re-request the key. let session = session.ok_or(MegolmError::MissingSession)?; - let (plaintext, _) = session.decrypt(content.ciphertext.clone()).await?; // TODO check the message index. // TODO check if this is from a verified device. + let (decrypted_event, _) = session.decrypt(event).await?; - // TODO move this logic into the group session. - let mut decrypted_value = serde_json::from_str::(&plaintext)?; - let decrypted_object = decrypted_value - .as_object_mut() - .ok_or(EventError::NotAnObject)?; - - // TODO better number conversion here. - let server_ts = event - .origin_server_ts - .duration_since(std::time::SystemTime::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(); - let server_ts: i64 = server_ts.try_into().unwrap_or_default(); - - decrypted_object.insert("sender".to_owned(), event.sender.to_string().into()); - decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into()); - decrypted_object.insert("origin_server_ts".to_owned(), server_ts.into()); - - decrypted_object.insert( - "unsigned".to_owned(), - serde_json::to_value(&event.unsigned).unwrap_or_default(), - ); - - let decrypted_event = - serde_json::from_value::>(decrypted_value)?; trace!("Successfully decrypted Megolm event {:?}", decrypted_event); // TODO set the encryption info on the event (is it verified, was it // decrypted, sender key...) diff --git a/matrix_sdk_crypto/src/olm.rs b/matrix_sdk_crypto/src/olm.rs index e575ca4d..5bc220b9 100644 --- a/matrix_sdk_crypto/src/olm.rs +++ b/matrix_sdk_crypto/src/olm.rs @@ -13,6 +13,7 @@ // limitations under the License. use matrix_sdk_common::instant::Instant; +use std::convert::TryInto; use std::fmt; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; @@ -31,6 +32,7 @@ use olm_rs::outbound_group_session::OlmOutboundGroupSession; use olm_rs::session::OlmSession; use olm_rs::PicklingMode; +use crate::error::{EventError, MegolmResult}; pub use olm_rs::{ session::{OlmMessage, PreKeyMessage}, utility::OlmUtility, @@ -44,7 +46,7 @@ use matrix_sdk_common::{ encrypted::{EncryptedEventContent, MegolmV1AesSha2Content}, message::MessageEventContent, }, - Algorithm, EventType, + Algorithm, AnyRoomEventStub, EventJson, EventType, MessageEventStub, }, }; @@ -642,9 +644,56 @@ impl InboundGroupSession { /// # Arguments /// /// * `message` - The message that should be decrypted. - pub async fn decrypt(&self, message: String) -> Result<(String, u32), OlmGroupSessionError> { + pub async fn decrypt_helper( + &self, + message: String, + ) -> Result<(String, u32), OlmGroupSessionError> { self.inner.lock().await.decrypt(message) } + + /// Decrypt an event from a room timeline. + /// + /// # Arguments + /// + /// * `event` - The event that should be decrypted. + pub async fn decrypt( + &self, + event: &MessageEventStub, + ) -> MegolmResult<(EventJson, u32)> { + let content = match &event.content { + EncryptedEventContent::MegolmV1AesSha2(c) => c, + _ => return Err(EventError::UnsupportedAlgorithm.into()), + }; + + let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?; + + let mut decrypted_value = serde_json::from_str::(&plaintext)?; + let decrypted_object = decrypted_value + .as_object_mut() + .ok_or(EventError::NotAnObject)?; + + // TODO better number conversion here. + let server_ts = event + .origin_server_ts + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + let server_ts: i64 = server_ts.try_into().unwrap_or_default(); + + decrypted_object.insert("sender".to_owned(), event.sender.to_string().into()); + decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into()); + decrypted_object.insert("origin_server_ts".to_owned(), server_ts.into()); + + decrypted_object.insert( + "unsigned".to_owned(), + serde_json::to_value(&event.unsigned).unwrap_or_default(), + ); + + Ok(( + serde_json::from_value::>(decrypted_value)?, + message_index, + )) + } } // #[cfg_attr(tarpaulin, skip)] @@ -996,6 +1045,9 @@ pub(crate) mod test { let plaintext = "This is a secret to everybody".to_owned(); let ciphertext = outbound.encrypt_helper(plaintext.clone()).await; - assert_eq!(plaintext, inbound.decrypt(ciphertext).await.unwrap().0); + assert_eq!( + plaintext, + inbound.decrypt_helper(ciphertext).await.unwrap().0 + ); } }