diff --git a/src/crypto/error.rs b/src/crypto/error.rs index dc51fc74..b03e8a14 100644 --- a/src/crypto/error.rs +++ b/src/crypto/error.rs @@ -13,7 +13,7 @@ // limitations under the License. use cjson::Error as CjsonError; -use olm_rs::errors::OlmSessionError; +use olm_rs::errors::{OlmGroupSessionError, OlmSessionError}; use serde_json::Error as SerdeError; use thiserror::Error; @@ -36,7 +36,9 @@ pub enum OlmError { #[error("the Encrypted message doesn't contain a ciphertext for our device")] MissingCiphertext, #[error("can't finish Olm Session operation {0}")] - OlmSessionError(#[from] OlmSessionError), + OlmSession(#[from] OlmSessionError), + #[error("can't finish Olm Session operation {0}")] + OlmGroupSession(#[from] OlmGroupSessionError), #[error("error deserializing a string to json")] JsonError(#[from] SerdeError), #[error("the provided JSON value isn't an object")] diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index afb1727a..b81ec861 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use super::error::{OlmError, Result, SignatureError, VerificationResult}; use super::memory_stores::GroupSessionStore; -use super::olm::Account; +use super::olm::{Account, InboundGroupSession}; #[cfg(feature = "sqlite-cryptostore")] use super::store::sqlite::SqliteStore; use super::store::MemoryStore; @@ -31,9 +31,7 @@ use crate::api; use api::r0::keys; use cjson; -use olm_rs::{ - inbound_group_session::OlmInboundGroupSession, session::OlmMessage, utility::OlmUtility, -}; +use olm_rs::{session::OlmMessage, utility::OlmUtility}; use serde_json::{json, Value}; use tokio::sync::Mutex; use tracing::{debug, info, instrument, trace, warn}; @@ -499,7 +497,7 @@ impl OlmMachine { .decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message) .await?; debug!("Decrypted a to-device event {:?}", decrypted_event); - self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event); + self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event)?; Ok(decrypted_event) } else { @@ -508,22 +506,35 @@ impl OlmMachine { } } - fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) { + fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> { match event.content.algorithm { Algorithm::MegolmV1AesSha2 => { // TODO check for all the valid fields. - let session = OlmInboundGroupSession::new(&event.content.session_key).unwrap(); - self.inbound_group_sessions - .add(&event.sender.to_string(), sender_key, session); + let session = InboundGroupSession::new( + sender_key, + &event.content.room_id.to_string(), + &event.content.session_key, + )?; + self.inbound_group_sessions.add(session); + // TODO save the session in the store. + Ok(()) + } + _ => { + warn!( + "Received room key with unsupported key algorithm {}", + event.content.algorithm + ); + Ok(()) } - _ => warn!( - "Received room key with unsupported key algorithm {}", - event.content.algorithm - ), } } - fn add_forwarded_room_key(&self, event: &ToDeviceForwardedRoomKey) { + fn add_forwarded_room_key( + &self, + sender_key: &str, + event: &ToDeviceForwardedRoomKey, + ) -> Result<()> { + Ok(()) // TODO } @@ -531,18 +542,21 @@ impl OlmMachine { &mut self, sender_key: &str, event: &EventResult, - ) { + ) -> Result<()> { let event = if let EventResult::Ok(e) = event { e } else { warn!("Decrypted to-device event failed to be parsed correctly"); - return; + return Ok(()); }; match event { ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e), - ToDeviceEvent::ForwardedRoomKey(e) => self.add_forwarded_room_key(e), - _ => warn!("Received a unexpected encrypted to-device event"), + ToDeviceEvent::ForwardedRoomKey(e) => self.add_forwarded_room_key(sender_key, e), + _ => { + warn!("Received a unexpected encrypted to-device event"); + Ok(()) + } } } @@ -555,6 +569,13 @@ impl OlmMachine { } #[instrument(skip(response))] + /// Handle a sync response and update the internal state of the Olm machine. + /// + /// This will decrypt to-device events but will not touch room messages. + /// + /// # Arguments + /// + /// * `response` - The sync latest sync response. pub async fn receive_sync_response(&mut self, response: &mut SyncResponse) { let one_time_key_count = response .device_one_time_keys_count diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index 8ae5e94f..6d0ac27c 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use olm_rs::inbound_group_session::OlmInboundGroupSession; +use super::olm::InboundGroupSession; use std::collections::HashMap; #[derive(Debug)] pub struct GroupSessionStore { - entries: HashMap>>, + entries: HashMap>>, } impl GroupSessionStore { @@ -27,23 +27,19 @@ impl GroupSessionStore { } } - pub fn add( - &mut self, - room_id: &str, - sender_key: &str, - session: OlmInboundGroupSession, - ) -> bool { - if !self.entries.contains_key(room_id) { - self.entries.insert(room_id.to_owned(), HashMap::new()); + pub fn add(&mut self, session: InboundGroupSession) -> bool { + if !self.entries.contains_key(&session.room_id) { + self.entries + .insert(session.room_id.to_owned(), HashMap::new()); } - let mut room_map = self.entries.get_mut(room_id).unwrap(); + let mut room_map = self.entries.get_mut(&session.room_id).unwrap(); - if !room_map.contains_key(sender_key) { - room_map.insert(sender_key.to_owned(), HashMap::new()); + if !room_map.contains_key(&session.sender_key) { + room_map.insert(session.sender_key.to_owned(), HashMap::new()); } - let mut sender_map = room_map.get_mut(sender_key).unwrap(); + let mut sender_map = room_map.get_mut(&session.sender_key).unwrap(); let ret = sender_map.insert(session.session_id(), session); ret.is_some() @@ -54,7 +50,7 @@ impl GroupSessionStore { room_id: &str, sender_key: &str, session_id: &str, - ) -> Option<&OlmInboundGroupSession> { + ) -> Option<&InboundGroupSession> { self.entries .get(room_id) .and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id))) diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index fd782130..a70fbe21 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -16,7 +16,8 @@ use std::fmt; use std::time::Instant; use olm_rs::account::{IdentityKeys, OlmAccount, OneTimeKeys}; -use olm_rs::errors::{OlmAccountError, OlmSessionError}; +use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; +use olm_rs::inbound_group_session::OlmInboundGroupSession; use olm_rs::session::{OlmMessage, OlmSession, PreKeyMessage}; use olm_rs::PicklingMode; @@ -42,6 +43,7 @@ impl fmt::Debug for Account { /// get Sync for it unsafe impl Send for Account {} unsafe impl Send for Session {} +unsafe impl Send for InboundGroupSession {} impl Account { /// Create a new account. @@ -149,6 +151,41 @@ impl Session { } } +#[derive(Debug)] +pub struct InboundGroupSession { + inner: OlmInboundGroupSession, + pub(crate) sender_key: String, + pub(crate) room_id: String, + forwarding_chains: Option>, +} + +impl InboundGroupSession { + pub fn new( + sender_key: &str, + room_id: &str, + session_key: &str, + ) -> Result { + Ok(InboundGroupSession { + inner: OlmInboundGroupSession::new(session_key)?, + sender_key: sender_key.to_owned(), + room_id: room_id.to_owned(), + forwarding_chains: None, + }) + } + + pub fn session_id(&self) -> String { + self.inner.session_id() + } + + pub fn first_known_index(&self) -> u32 { + self.inner.first_known_index() + } + + pub fn decrypt(&self, mut message: String) -> Result<(String, u32), OlmGroupSessionError> { + self.inner.decrypt(message) + } +} + #[cfg(test)] mod test { use crate::crypto::olm::Account;