crypto: Add some tests to our in-memory stores.

master
Damir Jelić 2020-04-15 15:32:58 +02:00
parent 202ab9b050
commit af73ebdf09
3 changed files with 108 additions and 7 deletions

View File

@ -22,19 +22,22 @@ use super::device::Device;
use super::olm::{InboundGroupSession, Session}; use super::olm::{InboundGroupSession, Session};
use crate::identifiers::{DeviceId, RoomId, UserId}; use crate::identifiers::{DeviceId, RoomId, UserId};
/// In-memory store for Olm Sessions.
#[derive(Debug)] #[derive(Debug)]
pub struct SessionStore { pub struct SessionStore {
entries: HashMap<String, Arc<Mutex<Vec<Session>>>>, entries: HashMap<String, Arc<Mutex<Vec<Session>>>>,
} }
impl SessionStore { impl SessionStore {
/// Create a new empty Session store.
pub fn new() -> Self { pub fn new() -> Self {
SessionStore { SessionStore {
entries: HashMap::new(), entries: HashMap::new(),
} }
} }
pub async fn add(&mut self, session: Session) -> Session { /// Add a session to the store.
pub async fn add(&mut self, session: Session) {
if !self.entries.contains_key(&*session.sender_key) { if !self.entries.contains_key(&*session.sender_key) {
self.entries.insert( self.entries.insert(
session.sender_key.to_string(), session.sender_key.to_string(),
@ -42,15 +45,15 @@ impl SessionStore {
); );
} }
let sessions = self.entries.get_mut(&*session.sender_key).unwrap(); let sessions = self.entries.get_mut(&*session.sender_key).unwrap();
sessions.lock().await.push(session.clone()); sessions.lock().await.push(session);
session
} }
/// Get all the sessions that belong to the given sender key.
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> { pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
self.entries.get(sender_key).cloned() self.entries.get(sender_key).cloned()
} }
/// Add a list of sessions belonging to the sender key.
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Session>) { pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Session>) {
self.entries self.entries
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions))); .insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
@ -58,17 +61,20 @@ impl SessionStore {
} }
#[derive(Debug)] #[derive(Debug)]
/// In-memory store that houlds inbound group sessions.
pub struct GroupSessionStore { pub struct GroupSessionStore {
entries: HashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>, entries: HashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>,
} }
impl GroupSessionStore { impl GroupSessionStore {
/// Create a new empty store.
pub fn new() -> Self { pub fn new() -> Self {
GroupSessionStore { GroupSessionStore {
entries: HashMap::new(), entries: HashMap::new(),
} }
} }
/// Add a inbound group session to the store.
pub fn add(&mut self, session: InboundGroupSession) -> bool { pub fn add(&mut self, session: InboundGroupSession) -> bool {
if !self.entries.contains_key(&session.room_id) { if !self.entries.contains_key(&session.room_id) {
let room_id = &*session.room_id; let room_id = &*session.room_id;
@ -88,6 +94,14 @@ impl GroupSessionStore {
ret.is_some() ret.is_some()
} }
/// Get a inbound group session from our store.
///
/// # Arguments
/// * `room_id` - The room id of the room that the session belongs to.
///
/// * `sender_key` - The sender key that sent us the session.
///
/// * `session_id` - The unique id of the session.
pub fn get( pub fn get(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
@ -158,3 +172,86 @@ impl DeviceStore {
} }
} }
} }
#[cfg(test)]
mod test {
use std::collections::HashMap;
use std::convert::TryFrom;
use crate::api::r0::keys::SignedKey;
use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore};
use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession, Session};
use crate::identifiers::RoomId;
async fn get_account_and_session() -> (Account, Session) {
let alice = Account::new();
let bob = Account::new();
bob.generate_one_time_keys(1).await;
let one_time_key = bob
.one_time_keys()
.await
.curve25519()
.iter()
.nth(0)
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey {
key: one_time_key,
signatures: HashMap::new(),
};
let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice
.create_outbound_session(&sender_key, &one_time_key)
.await
.unwrap();
(alice, session)
}
#[tokio::test]
async fn test_session_store() {
let (account, session) = get_account_and_session().await;
let mut store = SessionStore::new();
store.add(session.clone()).await;
let sessions = store.get(&session.sender_key).unwrap();
let sessions = sessions.lock().await;
let loaded_session = &sessions[0];
assert_eq!(&session, loaded_session);
}
#[tokio::test]
async fn test_group_session_store() {
let alice = Account::new();
let room_id = RoomId::try_from("!test:localhost").unwrap();
let outbound = OutboundGroupSession::new(&room_id);
assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
outbound.mark_as_shared();
assert!(outbound.shared());
let inbound = InboundGroupSession::new(
"test_key",
"test_key",
&room_id,
outbound.session_key().await,
)
.unwrap();
let mut store = GroupSessionStore::new();
store.add(inbound.clone());
let loaded_session = store
.get(&room_id, "test_key", outbound.session_id())
.unwrap();
assert_eq!(inbound, loaded_session);
}
}

View File

@ -497,6 +497,12 @@ impl fmt::Debug for InboundGroupSession {
} }
} }
impl PartialEq for InboundGroupSession {
fn eq(&self, other: &Self) -> bool {
self.session_id() == other.session_id()
}
}
/// Outbound group session. /// Outbound group session.
/// ///
/// Outbound group sessions are used to exchange room messages between a group /// Outbound group sessions are used to exchange room messages between a group

View File

@ -350,7 +350,7 @@ impl CryptoStore for SqliteStore {
} }
async fn add_and_save_session(&mut self, session: Session) -> Result<()> { async fn add_and_save_session(&mut self, session: Session) -> Result<()> {
let session = self.sessions.add(session).await; self.sessions.add(session.clone()).await;
self.save_session(session).await?; self.save_session(session).await?;
Ok(()) Ok(())
} }
@ -435,9 +435,7 @@ mod test {
use olm_rs::outbound_group_session::OlmOutboundGroupSession; use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use ruma_client_api::r0::keys::SignedKey; use ruma_client_api::r0::keys::SignedKey;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use tempfile::tempdir; use tempfile::tempdir;
use tokio::sync::Mutex;
use super::{ use super::{
Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom, UserId, Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom, UserId,