crypto: Add the MemoryStore back.

This commit is contained in:
Damir Jelić 2020-03-30 17:07:36 +02:00
parent b128a76c9c
commit ceeb685e1a
5 changed files with 148 additions and 49 deletions

View file

@ -20,8 +20,9 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use super::error::{OlmError, Result, SignatureError, VerificationResult};
use super::memory_stores::{GroupSessionStore, SessionStore};
use super::memory_stores::SessionStore;
use super::olm::{Account, InboundGroupSession, Session};
use super::store::memorystore::MemoryStore;
#[cfg(feature = "sqlite-cryptostore")]
use super::store::sqlite::SqliteStore;
use super::CryptoStore;
@ -68,11 +69,7 @@ pub struct OlmMachine {
/// Store for the encryption keys.
/// Persists all the encrytpion keys so a client can resume the session
/// without the need to create new keys.
store: Option<Box<dyn CryptoStore>>,
/// A cache of all the Olm sessions we know about.
sessions: SessionStore,
/// A cache of all the inbound group sessions we know about.
inbound_group_sessions: GroupSessionStore,
store: Box<dyn CryptoStore>,
}
impl OlmMachine {
@ -88,9 +85,7 @@ impl OlmMachine {
device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(Account::new())),
uploaded_signed_key_count: None,
store: None,
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
store: Box::new(MemoryStore::new()),
})
}
@ -122,9 +117,7 @@ impl OlmMachine {
device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(account)),
uploaded_signed_key_count: None,
store: Some(Box::new(store)),
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
store: Box::new(store),
})
}
@ -178,9 +171,7 @@ impl OlmMachine {
account.mark_keys_as_published();
drop(account);
if let Some(store) = self.store.as_mut() {
store.save_account(self.account.clone()).await?;
}
self.store.save_account(self.account.clone()).await?;
Ok(())
}
@ -404,7 +395,7 @@ impl OlmMachine {
sender_key: &str,
message: &OlmMessage,
) -> Result<Option<String>> {
let s = self.sessions.get(sender_key);
let s = self.store.get_sessions(sender_key).await?;
let sessions = if let Some(s) = s {
s
@ -412,7 +403,7 @@ impl OlmMachine {
return Ok(None);
};
for session in sessions {
for session in &*sessions.lock().await {
let mut matches = false;
let mut session_lock = session.lock().await;
@ -427,9 +418,7 @@ impl OlmMachine {
let ret = session_lock.decrypt(message.clone());
if let Ok(p) = ret {
if let Some(store) = self.store.as_mut() {
store.save_session(session.clone()).await?;
}
self.store.save_session(session.clone()).await?;
return Ok(Some(p));
} else {
if matches {
@ -459,13 +448,10 @@ impl OlmMachine {
};
let plaintext = session.decrypt(message)?;
self.sessions.add(session).await;
// TODO save the session
self.store.add_and_save_session(session).await?;
plaintext
};
// TODO convert the plaintext to a ruma event.
trace!("Successfully decrypted a Olm message: {}", plaintext);
Ok(serde_json::from_str::<EventResult<ToDeviceEvent>>(
&plaintext,
@ -511,7 +497,8 @@ impl OlmMachine {
.decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message)
.await?;
debug!("Decrypted a to-device event {:?}", decrypted_event);
self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event)?;
self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event)
.await?;
Ok(decrypted_event)
} else {
@ -520,7 +507,7 @@ impl OlmMachine {
}
}
fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> {
async fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> {
match event.content.algorithm {
Algorithm::MegolmV1AesSha2 => {
// TODO check for all the valid fields.
@ -535,8 +522,7 @@ impl OlmMachine {
&event.content.room_id.to_string(),
&event.content.session_key,
)?;
self.inbound_group_sessions.add(session);
// TODO save the session in the store.
self.store.save_inbound_group_session(session).await?;
Ok(())
}
_ => {
@ -558,7 +544,7 @@ impl OlmMachine {
// TODO
}
fn handle_decrypted_to_device_event(
async fn handle_decrypted_to_device_event(
&mut self,
sender_key: &str,
event: &EventResult<ToDeviceEvent>,
@ -571,7 +557,7 @@ impl OlmMachine {
};
match event {
ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e),
ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e).await,
ToDeviceEvent::ForwardedRoomKey(e) => self.add_forwarded_room_key(sender_key, e),
_ => {
warn!("Received a unexpected encrypted to-device event");
@ -645,7 +631,7 @@ impl OlmMachine {
}
pub async fn decrypt_room_event(
&self,
&mut self,
event: &EncryptedEvent,
) -> Result<EventResult<RoomEvent>> {
let content = match &event.content {
@ -655,11 +641,14 @@ impl OlmMachine {
let room_id = event.room_id.as_ref().unwrap();
let session = self.inbound_group_sessions.get(
&room_id.to_string(),
&content.sender_key,
&content.session_id,
);
let session = self
.store
.get_inbound_group_session(
&room_id.to_string(),
&content.sender_key,
&content.session_id,
)
.await?;
// TODO check if the olm session is wedged and re-request the key.
let session = session.ok_or(OlmError::MissingSession)?;

View file

@ -21,7 +21,7 @@ use super::olm::{InboundGroupSession, Session};
#[derive(Debug)]
pub struct SessionStore {
entries: HashMap<String, Vec<Arc<Mutex<Session>>>>,
entries: HashMap<String, Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>,
}
impl SessionStore {
@ -33,19 +33,22 @@ impl SessionStore {
pub async fn add(&mut self, session: Session) {
if !self.entries.contains_key(&session.sender_key) {
self.entries
.insert(session.sender_key.to_owned(), Vec::new());
self.entries.insert(
session.sender_key.to_owned(),
Arc::new(Mutex::new(Vec::new())),
);
}
let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
sessions.push(Arc::new(Mutex::new(session)));
sessions.lock().await.push(Arc::new(Mutex::new(session)));
}
pub fn get(&self, sender_key: &str) -> Option<&Vec<Arc<Mutex<Session>>>> {
self.entries.get(sender_key)
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> {
self.entries.get(sender_key).cloned()
}
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Arc<Mutex<Session>>>) {
self.entries.insert(sender_key.to_owned(), sessions);
self.entries
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
}
}

View file

@ -0,0 +1,79 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex;
use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session};
use crate::crypto::memory_stores::{GroupSessionStore, SessionStore};
#[derive(Debug)]
pub struct MemoryStore {
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
}
impl MemoryStore {
pub fn new() -> Self {
MemoryStore {
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
}
}
}
#[async_trait]
impl CryptoStore for MemoryStore {
async fn load_account(&mut self) -> Result<Option<Account>> {
Ok(None)
}
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
Ok(())
}
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()> {
Ok(())
}
async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
self.sessions.add(session).await;
Ok(())
}
async fn get_sessions(
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
Ok(self.sessions.get(sender_key))
}
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
self.inbound_group_sessions.add(session);
Ok(())
}
async fn get_inbound_group_session(
&mut self,
room_id: &str,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
Ok(self
.inbound_group_sessions
.get(room_id, sender_key, session_id))
}
}

View file

@ -28,6 +28,7 @@ use super::olm::{Account, InboundGroupSession, Session};
use olm_rs::errors::{OlmAccountError, OlmSessionError};
use olm_rs::PicklingMode;
pub mod memorystore;
#[cfg(feature = "sqlite-cryptostore")]
pub mod sqlite;
@ -64,6 +65,16 @@ pub trait CryptoStore: Debug + Send + Sync {
async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
async fn get_sessions(&mut self, sender_key: &str)
-> Result<Option<&Vec<Arc<Mutex<Session>>>>>;
async fn add_and_save_session(&mut self, session: Session) -> Result<()>;
async fn get_sessions(
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()>;
async fn get_inbound_group_session(
&mut self,
room_id: &str,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
}

View file

@ -149,7 +149,7 @@ impl SqliteStore {
async fn get_sessions_for(
&mut self,
sender_key: &str,
) -> Result<Option<&Vec<Arc<Mutex<Session>>>>> {
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
let loaded_sessions = self.sessions.get(sender_key).is_some();
if !loaded_sessions {
@ -332,12 +332,29 @@ impl CryptoStore for SqliteStore {
Ok(())
}
async fn get_sessions<'a>(
&'a mut self,
async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
todo!()
}
async fn get_sessions(
&mut self,
sender_key: &str,
) -> Result<Option<&'a Vec<Arc<Mutex<Session>>>>> {
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
Ok(self.get_sessions_for(sender_key).await?)
}
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
todo!()
}
async fn get_inbound_group_session(
&mut self,
room_id: &str,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
todo!()
}
}
impl std::fmt::Debug for SqliteStore {