diff --git a/src/base_client.rs b/src/base_client.rs index 8813ecef..229a1629 100644 --- a/src/base_client.rs +++ b/src/base_client.rs @@ -320,7 +320,7 @@ impl Client { let mut olm = self.olm.lock().await; if let Some(o) = &mut *olm { - o.receive_sync_response(response); + o.receive_sync_response(response).await; } } } diff --git a/src/crypto/error.rs b/src/crypto/error.rs index c378c15a..dc51fc74 100644 --- a/src/crypto/error.rs +++ b/src/crypto/error.rs @@ -13,6 +13,8 @@ // limitations under the License. use cjson::Error as CjsonError; +use olm_rs::errors::OlmSessionError; +use serde_json::Error as SerdeError; use thiserror::Error; use super::store::CryptoStoreError; @@ -29,6 +31,16 @@ pub enum OlmError { SessionWedged, #[error("the Olm message has a unsupported type")] UnsupportedOlmType, + #[error("the Encrypted message has been encrypted with a unsupported algorithm.")] + UnsupportedAlgorithm, + #[error("the Encrypted message doesn't contain a ciphertext for our device")] + MissingCiphertext, + #[error("can't finish Olm Session operation {0}")] + OlmSessionError(#[from] OlmSessionError), + #[error("error deserializing a string to json")] + JsonError(#[from] SerdeError), + #[error("the provided JSON value isn't an object")] + NotAnObject, } pub type VerificationResult = std::result::Result; diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 371acc93..77a5a5c1 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -403,7 +403,7 @@ impl OlmMachine { let mut matches = false; if let OlmMessage::PreKey(m) = &message { - matches = session.matches(sender_key, m.clone()).unwrap(); + matches = session.matches(sender_key, m.clone())?; if !matches { continue; } @@ -429,7 +429,7 @@ impl OlmMachine { sender: &str, sender_key: &str, message: OlmMessage, - ) -> Result> { + ) -> Result> { let plaintext = if let Some(p) = self.try_decrypt_olm_event(sender_key, &message).await? { p } else { @@ -437,18 +437,24 @@ impl OlmMachine { OlmMessage::Message(_) => return Err(OlmError::SessionWedged), OlmMessage::PreKey(m) => { let account = self.account.lock().await; - account - .create_inbound_session_from(sender_key, m.clone()) - .unwrap() + account.create_inbound_session_from(sender_key, m.clone())? } }; - session.decrypt(message).unwrap() + session.decrypt(message)? // TODO save the session }; // TODO convert the plaintext to a ruma event. - todo!() + let mut json_plaintext = serde_json::from_str::(&plaintext)?; + let json_object = json_plaintext + .as_object_mut() + .ok_or(OlmError::NotAnObject)?; + json_object.insert("sender".to_owned(), sender.into()); + + Ok(serde_json::from_value::>( + json_plaintext, + )?) } /// Decrypt a to-device event. @@ -461,16 +467,16 @@ impl OlmMachine { /// * `event` - The to-device event that should be decrypted. #[instrument] async fn decrypt_to_device_event( - &self, + &mut self, event: &ToDeviceEncrypted, - ) -> Result> { + ) -> Result> { info!("Decrypting to-device event"); let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content { c } else { warn!("Error, unsupported encryption algorithm"); - return Ok(None); + return Err(OlmError::UnsupportedAlgorithm); }; let identity_keys = self.account.lock().await.identity_keys(); @@ -478,13 +484,21 @@ impl OlmMachine { let own_ciphertext = content.ciphertext.get(own_key); if let Some(ciphertext) = own_ciphertext { - let message_type: u8 = ciphertext.message_type.try_into().unwrap(); + let message_type: u8 = ciphertext + .message_type + .try_into() + .map_err(|_| OlmError::UnsupportedOlmType)?; let message = OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone()) - .map_err(|_| OlmError::UnsupportedOlmType); - } + .map_err(|_| OlmError::UnsupportedOlmType)?; - todo!() + Ok(self + .decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message) + .await?) + } else { + warn!("Olm event doesn't contain a ciphertext for our key"); + Err(OlmError::MissingCiphertext) + } } fn handle_room_key_request(&self, _: &ToDeviceRoomKeyRequest) { @@ -496,7 +510,7 @@ impl OlmMachine { } #[instrument(skip(response))] - pub fn receive_sync_response(&mut self, response: &mut SyncResponse) { + pub async fn receive_sync_response(&mut self, response: &mut SyncResponse) { let one_time_key_count = response .device_one_time_keys_count .get(&keys::KeyAlgorithm::SignedCurve25519); @@ -519,7 +533,8 @@ impl OlmMachine { ToDeviceEvent::RoomEncrypted(e) => { // TODO put the decrypted event into a vec so we can replace // them in the sync response. - let _ = self.decrypt_to_device_event(e); + let decrypted_event = self.decrypt_to_device_event(e).await; + info!("Decrypted a to-device event {:?}", decrypted_event); } ToDeviceEvent::RoomKeyRequest(e) => self.handle_room_key_request(e), ToDeviceEvent::KeyVerificationAccept(..) diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index e9f0b887..c1d8ed10 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -12,9 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// TODO remove this. mod error; -#[allow(dead_code)] +// TODO remove this. mod machine; #[allow(dead_code)] mod olm; diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index 658258f8..7ce34854 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -24,28 +24,30 @@ use thiserror::Error; use tokio::sync::Mutex; use super::olm::{Account, Session}; -use olm_rs::errors::OlmAccountError; +use olm_rs::errors::{OlmAccountError, OlmSessionError}; use olm_rs::PicklingMode; #[cfg(feature = "sqlite-cryptostore")] pub mod sqlite; #[cfg(feature = "sqlite-cryptostore")] -use sqlx::Error as SqlxError; +use sqlx::{sqlite::Sqlite, Error as SqlxError}; #[derive(Error, Debug)] pub enum CryptoStoreError { #[error("can't read or write from the store")] Io(#[from] IoError), - #[error("can't finish Olm account operation {0}")] + #[error("can't finish Olm Account operation {0}")] OlmAccountError(#[from] OlmAccountError), + #[error("can't finish Olm Session operation {0}")] + OlmSessionError(#[from] OlmSessionError), #[error("URL can't be parsed")] UrlParse(#[from] ParseError), // TODO flatten the SqlxError to make it easier for other store // implementations. #[cfg(feature = "sqlite-cryptostore")] #[error("database error")] - DatabaseError(#[from] SqlxError), + DatabaseError(#[from] SqlxError), } pub type Result = std::result::Result;