diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 7ba0951e..8520b4f1 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -544,6 +544,9 @@ impl std::fmt::Debug for OutboundGroupSession { #[cfg(test)] mod test { use crate::crypto::olm::Account; + use olm_rs::session::OlmMessage; + use ruma_client_api::r0::keys::SignedKey; + use std::collections::HashMap; #[test] fn account_creation() { @@ -589,4 +592,52 @@ mod test { let one_time_keys = account.one_time_keys(); assert!(one_time_keys.curve25519().is_empty()); } + + #[test] + fn session_creation() { + let alice = Account::new(); + let bob = Account::new(); + let alice_keys = alice.identity_keys(); + let one_time_keys = alice.one_time_keys(); + + alice.generate_one_time_keys(1); + let one_time_keys = alice.one_time_keys(); + alice.mark_keys_as_published(); + + let one_time_key = one_time_keys + .curve25519() + .iter() + .nth(0) + .unwrap() + .1 + .to_owned(); + + let one_time_key = SignedKey { + key: one_time_key, + signatures: HashMap::new(), + }; + + let mut bob_session = bob + .create_outbound_session(alice_keys.curve25519(), &one_time_key) + .unwrap(); + + let plaintext = "Hello world"; + + let message = bob_session.encrypt(plaintext); + + let prekey_message = match message.clone() { + OlmMessage::PreKey(m) => m, + OlmMessage::Message(_) => panic!("Incorrect message type"), + }; + + let bob_keys = bob.identity_keys(); + let mut alice_session = alice + .create_inbound_session(bob_keys.curve25519(), prekey_message) + .unwrap(); + + assert_eq!(bob_session.session_id(), alice_session.session_id()); + + let decyrpted = alice_session.decrypt(message).unwrap(); + assert_eq!(plaintext, decyrpted); + } }