diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index 413e1e06..2d9f5e11 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -22,19 +22,22 @@ use super::device::Device; use super::olm::{InboundGroupSession, Session}; use crate::identifiers::{DeviceId, RoomId, UserId}; +/// In-memory store for Olm Sessions. #[derive(Debug)] pub struct SessionStore { entries: HashMap>>>, } impl SessionStore { + /// Create a new empty Session store. pub fn new() -> Self { SessionStore { entries: HashMap::new(), } } - pub async fn add(&mut self, session: Session) -> Session { + /// Add a session to the store. + pub async fn add(&mut self, session: Session) { if !self.entries.contains_key(&*session.sender_key) { self.entries.insert( session.sender_key.to_string(), @@ -42,15 +45,15 @@ impl SessionStore { ); } let sessions = self.entries.get_mut(&*session.sender_key).unwrap(); - sessions.lock().await.push(session.clone()); - - session + sessions.lock().await.push(session); } + /// Get all the sessions that belong to the given sender key. pub fn get(&self, sender_key: &str) -> Option>>> { self.entries.get(sender_key).cloned() } + /// Add a list of sessions belonging to the sender key. pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec) { self.entries .insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions))); @@ -58,17 +61,20 @@ impl SessionStore { } #[derive(Debug)] +/// In-memory store that houlds inbound group sessions. pub struct GroupSessionStore { entries: HashMap>>, } impl GroupSessionStore { + /// Create a new empty store. pub fn new() -> Self { GroupSessionStore { entries: HashMap::new(), } } + /// Add a inbound group session to the store. pub fn add(&mut self, session: InboundGroupSession) -> bool { if !self.entries.contains_key(&session.room_id) { let room_id = &*session.room_id; @@ -88,6 +94,14 @@ impl GroupSessionStore { ret.is_some() } + /// Get a inbound group session from our store. + /// + /// # Arguments + /// * `room_id` - The room id of the room that the session belongs to. + /// + /// * `sender_key` - The sender key that sent us the session. + /// + /// * `session_id` - The unique id of the session. pub fn get( &self, room_id: &RoomId, @@ -158,3 +172,86 @@ impl DeviceStore { } } } + +#[cfg(test)] +mod test { + use std::collections::HashMap; + use std::convert::TryFrom; + + use crate::api::r0::keys::SignedKey; + use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore}; + use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession, Session}; + use crate::identifiers::RoomId; + + async fn get_account_and_session() -> (Account, Session) { + let alice = Account::new(); + + let bob = Account::new(); + + bob.generate_one_time_keys(1).await; + let one_time_key = bob + .one_time_keys() + .await + .curve25519() + .iter() + .nth(0) + .unwrap() + .1 + .to_owned(); + let one_time_key = SignedKey { + key: one_time_key, + signatures: HashMap::new(), + }; + let sender_key = bob.identity_keys().curve25519().to_owned(); + let session = alice + .create_outbound_session(&sender_key, &one_time_key) + .await + .unwrap(); + + (alice, session) + } + + #[tokio::test] + async fn test_session_store() { + let (account, session) = get_account_and_session().await; + + let mut store = SessionStore::new(); + store.add(session.clone()).await; + + let sessions = store.get(&session.sender_key).unwrap(); + let sessions = sessions.lock().await; + + let loaded_session = &sessions[0]; + + assert_eq!(&session, loaded_session); + } + + #[tokio::test] + async fn test_group_session_store() { + let alice = Account::new(); + let room_id = RoomId::try_from("!test:localhost").unwrap(); + + let outbound = OutboundGroupSession::new(&room_id); + + assert_eq!(0, outbound.message_index().await); + assert!(!outbound.shared()); + outbound.mark_as_shared(); + assert!(outbound.shared()); + + let inbound = InboundGroupSession::new( + "test_key", + "test_key", + &room_id, + outbound.session_key().await, + ) + .unwrap(); + + let mut store = GroupSessionStore::new(); + store.add(inbound.clone()); + + let loaded_session = store + .get(&room_id, "test_key", outbound.session_id()) + .unwrap(); + assert_eq!(inbound, loaded_session); + } +} diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 72e28b2d..8d160bbd 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -497,6 +497,12 @@ impl fmt::Debug for InboundGroupSession { } } +impl PartialEq for InboundGroupSession { + fn eq(&self, other: &Self) -> bool { + self.session_id() == other.session_id() + } +} + /// Outbound group session. /// /// Outbound group sessions are used to exchange room messages between a group diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index cc4aea7c..99a859f4 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -350,7 +350,7 @@ impl CryptoStore for SqliteStore { } async fn add_and_save_session(&mut self, session: Session) -> Result<()> { - let session = self.sessions.add(session).await; + self.sessions.add(session.clone()).await; self.save_session(session).await?; Ok(()) } @@ -435,9 +435,7 @@ mod test { use olm_rs::outbound_group_session::OlmOutboundGroupSession; use ruma_client_api::r0::keys::SignedKey; use std::collections::HashMap; - use std::sync::Arc; use tempfile::tempdir; - use tokio::sync::Mutex; use super::{ Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom, UserId,