diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index b36dd68e..2510edb9 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -198,46 +198,17 @@ impl DeviceStore { #[cfg(test)] mod test { - use std::collections::HashMap; use std::convert::TryFrom; - use crate::api::r0::keys::SignedKey; use crate::crypto::device::test::get_device; use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore}; - use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession, Session}; + use crate::crypto::olm::test::get_account_and_session; + use crate::crypto::olm::{InboundGroupSession, OutboundGroupSession}; use crate::identifiers::RoomId; - async fn get_account_and_session() -> (Account, Session) { - let alice = Account::new(); - - let bob = Account::new(); - - bob.generate_one_time_keys(1).await; - let one_time_key = bob - .one_time_keys() - .await - .curve25519() - .iter() - .nth(0) - .unwrap() - .1 - .to_owned(); - let one_time_key = SignedKey { - key: one_time_key, - signatures: HashMap::new(), - }; - let sender_key = bob.identity_keys().curve25519().to_owned(); - let session = alice - .create_outbound_session(&sender_key, &one_time_key) - .await - .unwrap(); - - (alice, session) - } - #[tokio::test] async fn test_session_store() { - let (account, session) = get_account_and_session().await; + let (_, session) = get_account_and_session().await; let mut store = SessionStore::new(); @@ -254,7 +225,7 @@ mod test { #[tokio::test] async fn test_session_store_bulk_storing() { - let (account, session) = get_account_and_session().await; + let (_, session) = get_account_and_session().await; let mut store = SessionStore::new(); store.set_for_sender(&session.sender_key, vec![session.clone()]); @@ -269,7 +240,6 @@ mod test { #[tokio::test] async fn test_group_session_store() { - let alice = Account::new(); let room_id = RoomId::try_from("!test:localhost").unwrap(); let outbound = OutboundGroupSession::new(&room_id); diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 3dff2085..614f5bda 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -613,14 +613,42 @@ impl std::fmt::Debug for OutboundGroupSession { } #[cfg(test)] -mod test { - use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession}; +pub(crate) mod test { + use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession, Session}; use crate::identifiers::RoomId; use olm_rs::session::OlmMessage; use ruma_client_api::r0::keys::SignedKey; use std::collections::HashMap; use std::convert::TryFrom; + pub(crate) async fn get_account_and_session() -> (Account, Session) { + let alice = Account::new(); + + let bob = Account::new(); + + bob.generate_one_time_keys(1).await; + let one_time_key = bob + .one_time_keys() + .await + .curve25519() + .iter() + .nth(0) + .unwrap() + .1 + .to_owned(); + let one_time_key = SignedKey { + key: one_time_key, + signatures: HashMap::new(), + }; + let sender_key = bob.identity_keys().curve25519().to_owned(); + let session = alice + .create_outbound_session(&sender_key, &one_time_key) + .await + .unwrap(); + + (alice, session) + } + #[test] fn account_creation() { let account = Account::new(); @@ -724,7 +752,6 @@ mod test { #[tokio::test] async fn group_session_creation() { - let alice = Account::new(); let room_id = RoomId::try_from("!test:localhost").unwrap(); let outbound = OutboundGroupSession::new(&room_id); diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index 0e844177..0e889786 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -97,3 +97,102 @@ impl CryptoStore for MemoryStore { Ok(()) } } + +#[cfg(test)] +mod test { + use std::convert::TryFrom; + + use crate::crypto::device::test::get_device; + use crate::crypto::olm::test::get_account_and_session; + use crate::crypto::olm::{InboundGroupSession, OutboundGroupSession}; + use crate::crypto::store::memorystore::MemoryStore; + use crate::crypto::store::CryptoStore; + use crate::identifiers::RoomId; + + #[tokio::test] + async fn test_session_store() { + let (account, session) = get_account_and_session().await; + let mut store = MemoryStore::new(); + + assert!(store.load_account().await.unwrap().is_none()); + store.save_account(account).await.unwrap(); + + store.save_session(session.clone()).await.unwrap(); + + let sessions = store + .get_sessions(&session.sender_key) + .await + .unwrap() + .unwrap(); + let sessions = sessions.lock().await; + + let loaded_session = &sessions[0]; + + assert_eq!(&session, loaded_session); + } + + #[tokio::test] + async fn test_group_session_store() { + let room_id = RoomId::try_from("!test:localhost").unwrap(); + + let outbound = OutboundGroupSession::new(&room_id); + let inbound = InboundGroupSession::new( + "test_key", + "test_key", + &room_id, + outbound.session_key().await, + ) + .unwrap(); + + let mut store = MemoryStore::new(); + store + .save_inbound_group_session(inbound.clone()) + .await + .unwrap(); + + let loaded_session = store + .get_inbound_group_session(&room_id, "test_key", outbound.session_id()) + .await + .unwrap() + .unwrap(); + assert_eq!(inbound, loaded_session); + } + + #[tokio::test] + async fn test_device_store() { + let device = get_device(); + let store = MemoryStore::new(); + + store.save_device(device.clone()).await.unwrap(); + + let loaded_device = store + .get_device(device.user_id(), device.device_id()) + .await + .unwrap() + .unwrap(); + + assert_eq!(device, loaded_device); + + let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); + + assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id()); + assert_eq!(user_devices.devices().nth(0).unwrap(), &device); + + let loaded_device = user_devices.get(device.device_id()).unwrap(); + + assert_eq!(device, loaded_device); + } + + #[tokio::test] + async fn test_tracked_users() { + let device = get_device(); + let mut store = MemoryStore::new(); + + assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); + assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap()); + + let tracked_users = store.tracked_users(); + + tracked_users.contains(device.user_id()); + } +}