crypto: Add the MemoryStore back.
parent
b128a76c9c
commit
ceeb685e1a
|
@ -20,8 +20,9 @@ use std::result::Result as StdResult;
|
|||
use std::sync::Arc;
|
||||
|
||||
use super::error::{OlmError, Result, SignatureError, VerificationResult};
|
||||
use super::memory_stores::{GroupSessionStore, SessionStore};
|
||||
use super::memory_stores::SessionStore;
|
||||
use super::olm::{Account, InboundGroupSession, Session};
|
||||
use super::store::memorystore::MemoryStore;
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
use super::store::sqlite::SqliteStore;
|
||||
use super::CryptoStore;
|
||||
|
@ -68,11 +69,7 @@ pub struct OlmMachine {
|
|||
/// Store for the encryption keys.
|
||||
/// Persists all the encrytpion keys so a client can resume the session
|
||||
/// without the need to create new keys.
|
||||
store: Option<Box<dyn CryptoStore>>,
|
||||
/// A cache of all the Olm sessions we know about.
|
||||
sessions: SessionStore,
|
||||
/// A cache of all the inbound group sessions we know about.
|
||||
inbound_group_sessions: GroupSessionStore,
|
||||
store: Box<dyn CryptoStore>,
|
||||
}
|
||||
|
||||
impl OlmMachine {
|
||||
|
@ -88,9 +85,7 @@ impl OlmMachine {
|
|||
device_id: device_id.to_owned(),
|
||||
account: Arc::new(Mutex::new(Account::new())),
|
||||
uploaded_signed_key_count: None,
|
||||
store: None,
|
||||
sessions: SessionStore::new(),
|
||||
inbound_group_sessions: GroupSessionStore::new(),
|
||||
store: Box::new(MemoryStore::new()),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -122,9 +117,7 @@ impl OlmMachine {
|
|||
device_id: device_id.to_owned(),
|
||||
account: Arc::new(Mutex::new(account)),
|
||||
uploaded_signed_key_count: None,
|
||||
store: Some(Box::new(store)),
|
||||
sessions: SessionStore::new(),
|
||||
inbound_group_sessions: GroupSessionStore::new(),
|
||||
store: Box::new(store),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -178,9 +171,7 @@ impl OlmMachine {
|
|||
account.mark_keys_as_published();
|
||||
drop(account);
|
||||
|
||||
if let Some(store) = self.store.as_mut() {
|
||||
store.save_account(self.account.clone()).await?;
|
||||
}
|
||||
self.store.save_account(self.account.clone()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -404,7 +395,7 @@ impl OlmMachine {
|
|||
sender_key: &str,
|
||||
message: &OlmMessage,
|
||||
) -> Result<Option<String>> {
|
||||
let s = self.sessions.get(sender_key);
|
||||
let s = self.store.get_sessions(sender_key).await?;
|
||||
|
||||
let sessions = if let Some(s) = s {
|
||||
s
|
||||
|
@ -412,7 +403,7 @@ impl OlmMachine {
|
|||
return Ok(None);
|
||||
};
|
||||
|
||||
for session in sessions {
|
||||
for session in &*sessions.lock().await {
|
||||
let mut matches = false;
|
||||
|
||||
let mut session_lock = session.lock().await;
|
||||
|
@ -427,9 +418,7 @@ impl OlmMachine {
|
|||
let ret = session_lock.decrypt(message.clone());
|
||||
|
||||
if let Ok(p) = ret {
|
||||
if let Some(store) = self.store.as_mut() {
|
||||
store.save_session(session.clone()).await?;
|
||||
}
|
||||
self.store.save_session(session.clone()).await?;
|
||||
return Ok(Some(p));
|
||||
} else {
|
||||
if matches {
|
||||
|
@ -459,13 +448,10 @@ impl OlmMachine {
|
|||
};
|
||||
|
||||
let plaintext = session.decrypt(message)?;
|
||||
self.sessions.add(session).await;
|
||||
|
||||
// TODO save the session
|
||||
self.store.add_and_save_session(session).await?;
|
||||
plaintext
|
||||
};
|
||||
|
||||
// TODO convert the plaintext to a ruma event.
|
||||
trace!("Successfully decrypted a Olm message: {}", plaintext);
|
||||
Ok(serde_json::from_str::<EventResult<ToDeviceEvent>>(
|
||||
&plaintext,
|
||||
|
@ -511,7 +497,8 @@ impl OlmMachine {
|
|||
.decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message)
|
||||
.await?;
|
||||
debug!("Decrypted a to-device event {:?}", decrypted_event);
|
||||
self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event)?;
|
||||
self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event)
|
||||
.await?;
|
||||
|
||||
Ok(decrypted_event)
|
||||
} else {
|
||||
|
@ -520,7 +507,7 @@ impl OlmMachine {
|
|||
}
|
||||
}
|
||||
|
||||
fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> {
|
||||
async fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> {
|
||||
match event.content.algorithm {
|
||||
Algorithm::MegolmV1AesSha2 => {
|
||||
// TODO check for all the valid fields.
|
||||
|
@ -535,8 +522,7 @@ impl OlmMachine {
|
|||
&event.content.room_id.to_string(),
|
||||
&event.content.session_key,
|
||||
)?;
|
||||
self.inbound_group_sessions.add(session);
|
||||
// TODO save the session in the store.
|
||||
self.store.save_inbound_group_session(session).await?;
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
|
@ -558,7 +544,7 @@ impl OlmMachine {
|
|||
// TODO
|
||||
}
|
||||
|
||||
fn handle_decrypted_to_device_event(
|
||||
async fn handle_decrypted_to_device_event(
|
||||
&mut self,
|
||||
sender_key: &str,
|
||||
event: &EventResult<ToDeviceEvent>,
|
||||
|
@ -571,7 +557,7 @@ impl OlmMachine {
|
|||
};
|
||||
|
||||
match event {
|
||||
ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e),
|
||||
ToDeviceEvent::RoomKey(e) => self.add_room_key(sender_key, e).await,
|
||||
ToDeviceEvent::ForwardedRoomKey(e) => self.add_forwarded_room_key(sender_key, e),
|
||||
_ => {
|
||||
warn!("Received a unexpected encrypted to-device event");
|
||||
|
@ -645,7 +631,7 @@ impl OlmMachine {
|
|||
}
|
||||
|
||||
pub async fn decrypt_room_event(
|
||||
&self,
|
||||
&mut self,
|
||||
event: &EncryptedEvent,
|
||||
) -> Result<EventResult<RoomEvent>> {
|
||||
let content = match &event.content {
|
||||
|
@ -655,11 +641,14 @@ impl OlmMachine {
|
|||
|
||||
let room_id = event.room_id.as_ref().unwrap();
|
||||
|
||||
let session = self.inbound_group_sessions.get(
|
||||
let session = self
|
||||
.store
|
||||
.get_inbound_group_session(
|
||||
&room_id.to_string(),
|
||||
&content.sender_key,
|
||||
&content.session_id,
|
||||
);
|
||||
)
|
||||
.await?;
|
||||
// TODO check if the olm session is wedged and re-request the key.
|
||||
let session = session.ok_or(OlmError::MissingSession)?;
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ use super::olm::{InboundGroupSession, Session};
|
|||
|
||||
#[derive(Debug)]
|
||||
pub struct SessionStore {
|
||||
entries: HashMap<String, Vec<Arc<Mutex<Session>>>>,
|
||||
entries: HashMap<String, Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
|
@ -33,19 +33,22 @@ impl SessionStore {
|
|||
|
||||
pub async fn add(&mut self, session: Session) {
|
||||
if !self.entries.contains_key(&session.sender_key) {
|
||||
self.entries
|
||||
.insert(session.sender_key.to_owned(), Vec::new());
|
||||
self.entries.insert(
|
||||
session.sender_key.to_owned(),
|
||||
Arc::new(Mutex::new(Vec::new())),
|
||||
);
|
||||
}
|
||||
let mut sessions = self.entries.get_mut(&session.sender_key).unwrap();
|
||||
sessions.push(Arc::new(Mutex::new(session)));
|
||||
sessions.lock().await.push(Arc::new(Mutex::new(session)));
|
||||
}
|
||||
|
||||
pub fn get(&self, sender_key: &str) -> Option<&Vec<Arc<Mutex<Session>>>> {
|
||||
self.entries.get(sender_key)
|
||||
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>> {
|
||||
self.entries.get(sender_key).cloned()
|
||||
}
|
||||
|
||||
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Arc<Mutex<Session>>>) {
|
||||
self.entries.insert(sender_key.to_owned(), sessions);
|
||||
self.entries
|
||||
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session};
|
||||
use crate::crypto::memory_stores::{GroupSessionStore, SessionStore};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MemoryStore {
|
||||
sessions: SessionStore,
|
||||
inbound_group_sessions: GroupSessionStore,
|
||||
}
|
||||
|
||||
impl MemoryStore {
|
||||
pub fn new() -> Self {
|
||||
MemoryStore {
|
||||
sessions: SessionStore::new(),
|
||||
inbound_group_sessions: GroupSessionStore::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CryptoStore for MemoryStore {
|
||||
async fn load_account(&mut self) -> Result<Option<Account>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
|
||||
self.sessions.add(session).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_sessions(
|
||||
&mut self,
|
||||
sender_key: &str,
|
||||
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
|
||||
Ok(self.sessions.get(sender_key))
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
|
||||
self.inbound_group_sessions.add(session);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
&mut self,
|
||||
room_id: &str,
|
||||
sender_key: &str,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
||||
Ok(self
|
||||
.inbound_group_sessions
|
||||
.get(room_id, sender_key, session_id))
|
||||
}
|
||||
}
|
|
@ -28,6 +28,7 @@ use super::olm::{Account, InboundGroupSession, Session};
|
|||
use olm_rs::errors::{OlmAccountError, OlmSessionError};
|
||||
use olm_rs::PicklingMode;
|
||||
|
||||
pub mod memorystore;
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
pub mod sqlite;
|
||||
|
||||
|
@ -64,6 +65,16 @@ pub trait CryptoStore: Debug + Send + Sync {
|
|||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
||||
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
|
||||
async fn get_sessions(&mut self, sender_key: &str)
|
||||
-> Result<Option<&Vec<Arc<Mutex<Session>>>>>;
|
||||
async fn add_and_save_session(&mut self, session: Session) -> Result<()>;
|
||||
async fn get_sessions(
|
||||
&mut self,
|
||||
sender_key: &str,
|
||||
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()>;
|
||||
async fn get_inbound_group_session(
|
||||
&mut self,
|
||||
room_id: &str,
|
||||
sender_key: &str,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
|
||||
}
|
||||
|
|
|
@ -149,7 +149,7 @@ impl SqliteStore {
|
|||
async fn get_sessions_for(
|
||||
&mut self,
|
||||
sender_key: &str,
|
||||
) -> Result<Option<&Vec<Arc<Mutex<Session>>>>> {
|
||||
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
|
||||
let loaded_sessions = self.sessions.get(sender_key).is_some();
|
||||
|
||||
if !loaded_sessions {
|
||||
|
@ -332,12 +332,29 @@ impl CryptoStore for SqliteStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_sessions<'a>(
|
||||
&'a mut self,
|
||||
async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn get_sessions(
|
||||
&mut self,
|
||||
sender_key: &str,
|
||||
) -> Result<Option<&'a Vec<Arc<Mutex<Session>>>>> {
|
||||
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>> {
|
||||
Ok(self.get_sessions_for(sender_key).await?)
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
&mut self,
|
||||
room_id: &str,
|
||||
sender_key: &str,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SqliteStore {
|
||||
|
|
Loading…
Reference in New Issue