diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 8296137c..88a0bf95 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -20,8 +20,9 @@ use std::result::Result as StdResult; use std::sync::Arc; use super::error::{OlmError, Result, SignatureError, VerificationResult}; -use super::memory_stores::{GroupSessionStore, SessionStore}; +use super::memory_stores::SessionStore; use super::olm::{Account, InboundGroupSession, Session}; +use super::store::memorystore::MemoryStore; #[cfg(feature = "sqlite-cryptostore")] use super::store::sqlite::SqliteStore; use super::CryptoStore; @@ -68,11 +69,7 @@ pub struct OlmMachine { /// Store for the encryption keys. /// Persists all the encrytpion keys so a client can resume the session /// without the need to create new keys. - store: Option>, - /// A cache of all the Olm sessions we know about. - sessions: SessionStore, - /// A cache of all the inbound group sessions we know about. - inbound_group_sessions: GroupSessionStore, + store: Box, } impl OlmMachine { @@ -88,9 +85,7 @@ impl OlmMachine { device_id: device_id.to_owned(), account: Arc::new(Mutex::new(Account::new())), uploaded_signed_key_count: None, - store: None, - sessions: SessionStore::new(), - inbound_group_sessions: GroupSessionStore::new(), + store: Box::new(MemoryStore::new()), }) } @@ -122,9 +117,7 @@ impl OlmMachine { device_id: device_id.to_owned(), account: Arc::new(Mutex::new(account)), uploaded_signed_key_count: None, - store: Some(Box::new(store)), - sessions: SessionStore::new(), - inbound_group_sessions: GroupSessionStore::new(), + store: Box::new(store), }) } @@ -178,9 +171,7 @@ impl OlmMachine { account.mark_keys_as_published(); drop(account); - if let Some(store) = self.store.as_mut() { - store.save_account(self.account.clone()).await?; - } + self.store.save_account(self.account.clone()).await?; Ok(()) } @@ -404,7 +395,7 @@ impl OlmMachine { sender_key: &str, message: &OlmMessage, ) -> Result> { - let s = self.sessions.get(sender_key); + let s = self.store.get_sessions(sender_key).await?; let sessions = if let Some(s) = s { s @@ -412,7 +403,7 @@ impl OlmMachine { return Ok(None); }; - for session in sessions { + for session in &*sessions.lock().await { let mut matches = false; let mut session_lock = session.lock().await; @@ -427,9 +418,7 @@ impl OlmMachine { let ret = session_lock.decrypt(message.clone()); if let Ok(p) = ret { - if let Some(store) = self.store.as_mut() { - store.save_session(session.clone()).await?; - } + self.store.save_session(session.clone()).await?; return Ok(Some(p)); } else { if matches { @@ -459,13 +448,10 @@ impl OlmMachine { }; let plaintext = session.decrypt(message)?; - self.sessions.add(session).await; - - // TODO save the session + self.store.add_and_save_session(session).await?; plaintext }; - // TODO convert the plaintext to a ruma event. trace!("Successfully decrypted a Olm message: {}", plaintext); Ok(serde_json::from_str::>( &plaintext, @@ -511,7 +497,8 @@ impl OlmMachine { .decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message) .await?; debug!("Decrypted a to-device event {:?}", decrypted_event); - self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event)?; + self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event) + .await?; Ok(decrypted_event) } else { @@ -520,7 +507,7 @@ impl OlmMachine { } } - fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> { + async fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> { match event.content.algorithm { Algorithm::MegolmV1AesSha2 => { // TODO check for all the valid fields. @@ -535,8 +522,7 @@ impl OlmMachine { &event.content.room_id.to_string(), &event.content.session_key, )?; - self.inbound_group_sessions.add(session); - // TODO save the session in the store. + self.store.save_inbound_group_session(session).await?; Ok(()) } _ => { @@ -558,7 +544,7 @@ impl OlmMachine { // TODO } - fn handle_decrypted_to_device_event( + async fn handle_decrypted_to_device_event( &mut self, sender_key: &str, event: &EventResult, @@ -571,7 +557,7 @@ impl OlmMachine { }; match event { - ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e), + ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e).await, ToDeviceEvent::ForwardedRoomKey(e) => self.add_forwarded_room_key(sender_key, e), _ => { warn!("Received a unexpected encrypted to-device event"); @@ -645,7 +631,7 @@ impl OlmMachine { } pub async fn decrypt_room_event( - &self, + &mut self, event: &EncryptedEvent, ) -> Result> { let content = match &event.content { @@ -655,11 +641,14 @@ impl OlmMachine { let room_id = event.room_id.as_ref().unwrap(); - let session = self.inbound_group_sessions.get( - &room_id.to_string(), - &content.sender_key, - &content.session_id, - ); + let session = self + .store + .get_inbound_group_session( + &room_id.to_string(), + &content.sender_key, + &content.session_id, + ) + .await?; // TODO check if the olm session is wedged and re-request the key. let session = session.ok_or(OlmError::MissingSession)?; diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index 56b68177..5b7b5820 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -21,7 +21,7 @@ use super::olm::{InboundGroupSession, Session}; #[derive(Debug)] pub struct SessionStore { - entries: HashMap>>>, + entries: HashMap>>>>>, } impl SessionStore { @@ -33,19 +33,22 @@ impl SessionStore { pub async fn add(&mut self, session: Session) { if !self.entries.contains_key(&session.sender_key) { - self.entries - .insert(session.sender_key.to_owned(), Vec::new()); + self.entries.insert( + session.sender_key.to_owned(), + Arc::new(Mutex::new(Vec::new())), + ); } let mut sessions = self.entries.get_mut(&session.sender_key).unwrap(); - sessions.push(Arc::new(Mutex::new(session))); + sessions.lock().await.push(Arc::new(Mutex::new(session))); } - pub fn get(&self, sender_key: &str) -> Option<&Vec>>> { - self.entries.get(sender_key) + pub fn get(&self, sender_key: &str) -> Option>>>>> { + self.entries.get(sender_key).cloned() } pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec>>) { - self.entries.insert(sender_key.to_owned(), sessions); + self.entries + .insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions))); } } diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs new file mode 100644 index 00000000..f3744f79 --- /dev/null +++ b/src/crypto/store/memorystore.rs @@ -0,0 +1,79 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::Mutex; + +use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; +use crate::crypto::memory_stores::{GroupSessionStore, SessionStore}; + +#[derive(Debug)] +pub struct MemoryStore { + sessions: SessionStore, + inbound_group_sessions: GroupSessionStore, +} + +impl MemoryStore { + pub fn new() -> Self { + MemoryStore { + sessions: SessionStore::new(), + inbound_group_sessions: GroupSessionStore::new(), + } + } +} + +#[async_trait] +impl CryptoStore for MemoryStore { + async fn load_account(&mut self) -> Result> { + Ok(None) + } + + async fn save_account(&mut self, account: Arc>) -> Result<()> { + Ok(()) + } + + async fn save_session(&mut self, session: Arc>) -> Result<()> { + Ok(()) + } + + async fn add_and_save_session(&mut self, session: Session) -> Result<()> { + self.sessions.add(session).await; + Ok(()) + } + + async fn get_sessions( + &mut self, + sender_key: &str, + ) -> Result>>>>>> { + Ok(self.sessions.get(sender_key)) + } + + async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> { + self.inbound_group_sessions.add(session); + Ok(()) + } + + async fn get_inbound_group_session( + &mut self, + room_id: &str, + sender_key: &str, + session_id: &str, + ) -> Result>>> { + Ok(self + .inbound_group_sessions + .get(room_id, sender_key, session_id)) + } +} diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index beb660db..ae2f1182 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -28,6 +28,7 @@ use super::olm::{Account, InboundGroupSession, Session}; use olm_rs::errors::{OlmAccountError, OlmSessionError}; use olm_rs::PicklingMode; +pub mod memorystore; #[cfg(feature = "sqlite-cryptostore")] pub mod sqlite; @@ -64,6 +65,16 @@ pub trait CryptoStore: Debug + Send + Sync { async fn load_account(&mut self) -> Result>; async fn save_account(&mut self, account: Arc>) -> Result<()>; async fn save_session(&mut self, session: Arc>) -> Result<()>; - async fn get_sessions(&mut self, sender_key: &str) - -> Result>>>>; + async fn add_and_save_session(&mut self, session: Session) -> Result<()>; + async fn get_sessions( + &mut self, + sender_key: &str, + ) -> Result>>>>>>; + async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()>; + async fn get_inbound_group_session( + &mut self, + room_id: &str, + sender_key: &str, + session_id: &str, + ) -> Result>>>; } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 865f2c63..5cd57e8a 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -149,7 +149,7 @@ impl SqliteStore { async fn get_sessions_for( &mut self, sender_key: &str, - ) -> Result>>>> { + ) -> Result>>>>>> { let loaded_sessions = self.sessions.get(sender_key).is_some(); if !loaded_sessions { @@ -332,12 +332,29 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn get_sessions<'a>( - &'a mut self, + async fn add_and_save_session(&mut self, session: Session) -> Result<()> { + todo!() + } + + async fn get_sessions( + &mut self, sender_key: &str, - ) -> Result>>>> { + ) -> Result>>>>>> { Ok(self.get_sessions_for(sender_key).await?) } + + async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> { + todo!() + } + + async fn get_inbound_group_session( + &mut self, + room_id: &str, + sender_key: &str, + session_id: &str, + ) -> Result>>> { + todo!() + } } impl std::fmt::Debug for SqliteStore {