crypto: Add the MemoryStore back.

master
Damir Jelić 2020-03-30 17:07:36 +02:00
parent b128a76c9c
commit ceeb685e1a
5 changed files with 148 additions and 49 deletions

View File

@ -20,8 +20,9 @@ use std::result::Result as StdResult;
use std::sync::Arc; use std::sync::Arc;
use super::error::{OlmError, Result, SignatureError, VerificationResult}; use super::error::{OlmError, Result, SignatureError, VerificationResult};
use super::memory_stores::{GroupSessionStore, SessionStore}; use super::memory_stores::SessionStore;
use super::olm::{Account, InboundGroupSession, Session}; use super::olm::{Account, InboundGroupSession, Session};
use super::store::memorystore::MemoryStore;
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
use super::store::sqlite::SqliteStore; use super::store::sqlite::SqliteStore;
use super::CryptoStore; use super::CryptoStore;
@ -68,11 +69,7 @@ pub struct OlmMachine {
/// Store for the encryption keys. /// Store for the encryption keys.
/// Persists all the encrytpion keys so a client can resume the session /// Persists all the encrytpion keys so a client can resume the session
/// without the need to create new keys. /// without the need to create new keys.
store: Option<Box<dyn CryptoStore>>, store: Box<dyn CryptoStore>,
/// A cache of all the Olm sessions we know about.
sessions: SessionStore,
/// A cache of all the inbound group sessions we know about.
inbound_group_sessions: GroupSessionStore,
} }
impl OlmMachine { impl OlmMachine {
@ -88,9 +85,7 @@ impl OlmMachine {
device_id: device_id.to_owned(), device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(Account::new())), account: Arc::new(Mutex::new(Account::new())),
uploaded_signed_key_count: None, uploaded_signed_key_count: None,
store: None, store: Box::new(MemoryStore::new()),
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
}) })
} }
@ -122,9 +117,7 @@ impl OlmMachine {
device_id: device_id.to_owned(), device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(account)), account: Arc::new(Mutex::new(account)),
uploaded_signed_key_count: None, uploaded_signed_key_count: None,
store: Some(Box::new(store)), store: Box::new(store),
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
}) })
} }
@ -178,9 +171,7 @@ impl OlmMachine {
account.mark_keys_as_published(); account.mark_keys_as_published();
drop(account); drop(account);
if let Some(store) = self.store.as_mut() { self.store.save_account(self.account.clone()).await?;
store.save_account(self.account.clone()).await?;
}
Ok(()) Ok(())
} }
@ -404,7 +395,7 @@ impl OlmMachine {
sender_key: &str, sender_key: &str,
message: &OlmMessage, message: &OlmMessage,
) -> Result<Option<String>> { ) -> Result<Option<String>> {
let s = self.sessions.get(sender_key); let s = self.store.get_sessions(sender_key).await?;
let sessions = if let Some(s) = s { let sessions = if let Some(s) = s {
s s
@ -412,7 +403,7 @@ impl OlmMachine {
return Ok(None); return Ok(None);
}; };
for session in sessions { for session in &*sessions.lock().await {
let mut matches = false; let mut matches = false;
let mut session_lock = session.lock().await; let mut session_lock = session.lock().await;
@ -427,9 +418,7 @@ impl OlmMachine {
let ret = session_lock.decrypt(message.clone()); let ret = session_lock.decrypt(message.clone());
if let Ok(p) = ret { if let Ok(p) = ret {
if let Some(store) = self.store.as_mut() { self.store.save_session(session.clone()).await?;
store.save_session(session.clone()).await?;
}
return Ok(Some(p)); return Ok(Some(p));
} else { } else {
if matches { if matches {
@ -459,13 +448,10 @@ impl OlmMachine {
}; };
let plaintext = session.decrypt(message)?; let plaintext = session.decrypt(message)?;
self.sessions.add(session).await; self.store.add_and_save_session(session).await?;
// TODO save the session
plaintext plaintext
}; };
// TODO convert the plaintext to a ruma event.
trace!("Successfully decrypted a Olm message: {}", plaintext); trace!("Successfully decrypted a Olm message: {}", plaintext);
Ok(serde_json::from_str::<EventResult<ToDeviceEvent>>( Ok(serde_json::from_str::<EventResult<ToDeviceEvent>>(
&plaintext, &plaintext,
@ -511,7 +497,8 @@ impl OlmMachine {
.decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message) .decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message)
.await?; .await?;
debug!("Decrypted a to-device event {:?}", decrypted_event); debug!("Decrypted a to-device event {:?}", decrypted_event);
self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event)?; self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event)
.await?;
Ok(decrypted_event) Ok(decrypted_event)
} else { } else {
@ -520,7 +507,7 @@ impl OlmMachine {
} }
} }
fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> { async fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> {
match event.content.algorithm { match event.content.algorithm {
Algorithm::MegolmV1AesSha2 => { Algorithm::MegolmV1AesSha2 => {
// TODO check for all the valid fields. // TODO check for all the valid fields.
@ -535,8 +522,7 @@ impl OlmMachine {
&event.content.room_id.to_string(), &event.content.room_id.to_string(),
&event.content.session_key, &event.content.session_key,
)?; )?;
self.inbound_group_sessions.add(session); self.store.save_inbound_group_session(session).await?;
// TODO save the session in the store.
Ok(()) Ok(())
} }
_ => { _ => {
@ -558,7 +544,7 @@ impl OlmMachine {
// TODO // TODO
} }
fn handle_decrypted_to_device_event( async fn handle_decrypted_to_device_event(
&mut self, &mut self,
sender_key: &str, sender_key: &str,
event: &EventResult<ToDeviceEvent>, event: &EventResult<ToDeviceEvent>,
@ -571,7 +557,7 @@ impl OlmMachine {
}; };
match event { match event {
ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e), ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e).await,
ToDeviceEvent::ForwardedRoomKey(e) => self.add_forwarded_room_key(sender_key, e), ToDeviceEvent::ForwardedRoomKey(e) => self.add_forwarded_room_key(sender_key, e),
_ => { _ => {
warn!("Received a unexpected encrypted to-device event"); warn!("Received a unexpected encrypted to-device event");
@ -645,7 +631,7 @@ impl OlmMachine {
} }
pub async fn decrypt_room_event( pub async fn decrypt_room_event(
&self, &mut self,
event: &EncryptedEvent, event: &EncryptedEvent,
) -> Result<EventResult<RoomEvent>> { ) -> Result<EventResult<RoomEvent>> {
let content = match &event.content { let content = match &event.content {
@ -655,11 +641,14 @@ impl OlmMachine {
let room_id = event.room_id.as_ref().unwrap(); let room_id = event.room_id.as_ref().unwrap();
let session = self.inbound_group_sessions.get( let session = self
.store
.get_inbound_group_session(
&room_id.to_string(), &room_id.to_string(),
&content.sender_key, &content.sender_key,
&content.session_id, &content.session_id,
); )
.await?;
// TODO check if the olm session is wedged and re-request the key. // TODO check if the olm session is wedged and re-request the key.
let session = session.ok_or(OlmError::MissingSession)?; let session = session.ok_or(OlmError::MissingSession)?;

View File

@ -21,7 +21,7 @@ use super::olm::{InboundGroupSession, Session};
#[derive(Debug)] #[derive(Debug)]
pub struct SessionStore { pub struct SessionStore {
entries: HashMap<String, Vec<Arc<Mutex<Session>>>>, entries: HashMap<String, Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>,
} }
impl SessionStore { impl SessionStore {
@ -33,19 +33,22 @@ impl SessionStore {
pub async fn add(&mut self, session: Session) { pub async fn add(&mut self, session: Session) {
if !self.entries.contains_key(&session.sender_key) { if !self.entries.contains_key(&session.sender_key) {
self.entries self.entries.insert(
.insert(session.sender_key.to_owned(), Vec::new()); session.sender_key.to_owned(),
Arc::new(Mutex::new(Vec::new())),
);
} }
let mut sessions = self.entries.get_mut(&session.sender_key).unwrap(); let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
sessions.push(Arc::new(Mutex::new(session))); sessions.lock().await.push(Arc::new(Mutex::new(session)));
} }
pub fn get(&self, sender_key: &str) -> Option<&Vec<Arc<Mutex<Session>>>> { pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> {
self.entries.get(sender_key) self.entries.get(sender_key).cloned()
} }
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Arc<Mutex<Session>>>) { pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Arc<Mutex<Session>>>) {
self.entries.insert(sender_key.to_owned(), sessions); self.entries
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
} }
} }

View File

@ -0,0 +1,79 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex;
use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session};
use crate::crypto::memory_stores::{GroupSessionStore, SessionStore};
#[derive(Debug)]
pub struct MemoryStore {
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
}
impl MemoryStore {
pub fn new() -> Self {
MemoryStore {
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
}
}
}
#[async_trait]
impl CryptoStore for MemoryStore {
async fn load_account(&mut self) -> Result<Option<Account>> {
Ok(None)
}
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
Ok(())
}
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()> {
Ok(())
}
async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
self.sessions.add(session).await;
Ok(())
}
async fn get_sessions(
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
Ok(self.sessions.get(sender_key))
}
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
self.inbound_group_sessions.add(session);
Ok(())
}
async fn get_inbound_group_session(
&mut self,
room_id: &str,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
Ok(self
.inbound_group_sessions
.get(room_id, sender_key, session_id))
}
}

View File

@ -28,6 +28,7 @@ use super::olm::{Account, InboundGroupSession, Session};
use olm_rs::errors::{OlmAccountError, OlmSessionError}; use olm_rs::errors::{OlmAccountError, OlmSessionError};
use olm_rs::PicklingMode; use olm_rs::PicklingMode;
pub mod memorystore;
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
pub mod sqlite; pub mod sqlite;
@ -64,6 +65,16 @@ pub trait CryptoStore: Debug + Send + Sync {
async fn load_account(&mut self) -> Result<Option<Account>>; async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>; async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>; async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
async fn get_sessions(&mut self, sender_key: &str) async fn add_and_save_session(&mut self, session: Session) -> Result<()>;
-> Result<Option<&Vec<Arc<Mutex<Session>>>>>; async fn get_sessions(
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()>;
async fn get_inbound_group_session(
&mut self,
room_id: &str,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
} }

View File

@ -149,7 +149,7 @@ impl SqliteStore {
async fn get_sessions_for( async fn get_sessions_for(
&mut self, &mut self,
sender_key: &str, sender_key: &str,
) -> Result<Option<&Vec<Arc<Mutex<Session>>>>> { ) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
let loaded_sessions = self.sessions.get(sender_key).is_some(); let loaded_sessions = self.sessions.get(sender_key).is_some();
if !loaded_sessions { if !loaded_sessions {
@ -332,12 +332,29 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn get_sessions<'a>( async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
&'a mut self, todo!()
}
async fn get_sessions(
&mut self,
sender_key: &str, sender_key: &str,
) -> Result<Option<&'a Vec<Arc<Mutex<Session>>>>> { ) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
Ok(self.get_sessions_for(sender_key).await?) Ok(self.get_sessions_for(sender_key).await?)
} }
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
todo!()
}
async fn get_inbound_group_session(
&mut self,
room_id: &str,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
todo!()
}
} }
impl std::fmt::Debug for SqliteStore { impl std::fmt::Debug for SqliteStore {