crypto: Make the session stores thread safe.

master
Damir Jelić 2020-03-26 11:22:40 +01:00
parent 50167e7988
commit abe13d7a2d
5 changed files with 50 additions and 30 deletions

View File

@ -398,7 +398,7 @@ impl OlmMachine {
sender_key: &str,
message: &OlmMessage,
) -> Result<Option<String>> {
let mut s = self.store.sessions_mut(sender_key).await?;
let mut s = self.sessions.get_mut(sender_key).await;
let sessions = if let Some(s) = s {
s
@ -406,7 +406,7 @@ impl OlmMachine {
return Ok(None);
};
for session in sessions.iter_mut() {
for session in sessions.lock().await.iter_mut() {
let mut matches = false;
if let OlmMessage::PreKey(m) = &message {
@ -448,8 +448,10 @@ impl OlmMachine {
}
};
session.decrypt(message)?
let plaintext = session.decrypt(message)?;
self.sessions.add(session).await;
// TODO save the session
plaintext
};
// TODO convert the plaintext to a ruma event.
@ -644,7 +646,7 @@ impl OlmMachine {
// TODO check if the olm session is wedged and re-request the key.
let session = session.ok_or(OlmError::MissingSession)?;
let (plaintext, _) = session.decrypt(content.ciphertext.clone())?;
let (plaintext, _) = session.lock().await.decrypt(content.ciphertext.clone())?;
// TODO check the message index.
// TODO check if this is from a verified device.

View File

@ -12,12 +12,46 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::olm::InboundGroupSession;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use super::olm::{InboundGroupSession, Session};
#[derive(Debug)]
pub struct SessionStore {
entries: Mutex<HashMap<String, Arc<Mutex<Vec<Session>>>>>,
}
impl SessionStore {
pub fn new() -> Self {
SessionStore {
entries: Mutex::new(HashMap::new()),
}
}
pub async fn add(&mut self, session: Session) {
let mut entries = self.entries.lock().await;
if !entries.contains_key(&session.sender_key) {
entries.insert(
session.sender_key.to_owned(),
Arc::new(Mutex::new(Vec::new())),
);
}
let mut sessions = entries.get_mut(&session.sender_key).unwrap();
sessions.lock().await.push(session);
}
pub async fn get_mut(&mut self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
self.entries.lock().await.get_mut(sender_key).cloned()
}
}
#[derive(Debug)]
pub struct GroupSessionStore {
entries: HashMap<String, HashMap<String, HashMap<String, InboundGroupSession>>>,
entries: HashMap<String, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>,
}
impl GroupSessionStore {
@ -40,7 +74,7 @@ impl GroupSessionStore {
}
let mut sender_map = room_map.get_mut(&session.sender_key).unwrap();
let ret = sender_map.insert(session.session_id(), session);
let ret = sender_map.insert(session.session_id(), Arc::new(Mutex::new(session)));
ret.is_some()
}
@ -50,9 +84,9 @@ impl GroupSessionStore {
room_id: &str,
sender_key: &str,
session_id: &str,
) -> Option<&InboundGroupSession> {
) -> Option<Arc<Mutex<InboundGroupSession>>> {
self.entries
.get(room_id)
.and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id)))
.and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id).cloned()))
}
}

View File

@ -116,6 +116,7 @@ impl Account {
Ok(Session {
inner: session,
sender_key: their_identity_key.to_owned(),
creation_time: now.clone(),
last_use_time: now,
})
@ -128,8 +129,10 @@ impl PartialEq for Account {
}
}
#[derive(Debug)]
pub struct Session {
inner: OlmSession,
pub(crate) sender_key: String,
creation_time: Instant,
last_use_time: Instant,
}

View File

@ -53,24 +53,19 @@ pub enum CryptoStoreError {
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
#[async_trait]
pub trait CryptoStore: Debug {
pub trait CryptoStore: Debug + Sync + Sync {
async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
async fn sessions_mut(&mut self, sender_key: &str) -> Result<Option<&mut Vec<Session>>>;
}
pub struct MemoryStore {
pub(crate) account_info: Option<(String, bool)>,
sessions: HashMap<String, Vec<Session>>,
}
impl MemoryStore {
/// Create a new empty memory store.
pub fn new() -> Self {
MemoryStore {
account_info: None,
sessions: HashMap::new(),
}
MemoryStore { account_info: None }
}
}
@ -94,13 +89,6 @@ impl CryptoStore for MemoryStore {
self.account_info = Some((pickle, acc.shared));
Ok(())
}
async fn sessions_mut<'a>(
&'a mut self,
sender_key: &str,
) -> Result<Option<&'a mut Vec<Session>>> {
Ok(self.sessions.get_mut(sender_key))
}
}
impl std::fmt::Debug for MemoryStore {

View File

@ -168,13 +168,6 @@ impl CryptoStore for SqliteStore {
Ok(())
}
async fn sessions_mut<'a>(
&'a mut self,
sender_key: &str,
) -> Result<Option<&'a mut Vec<Session>>> {
todo!()
}
}
impl std::fmt::Debug for SqliteStore {