diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 659617cd..f8259873 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -13,17 +13,22 @@ // limitations under the License. use std::fmt; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; use std::time::Instant; +use tokio::sync::Mutex; + use olm_rs::account::{IdentityKeys, OlmAccount, OneTimeKeys}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::inbound_group_session::OlmInboundGroupSession; +use olm_rs::outbound_group_session::OlmOutboundGroupSession; use olm_rs::session::{OlmMessage, OlmSession, PreKeyMessage}; use olm_rs::PicklingMode; use ruma_client_api::r0::keys::SignedKey; -use crate::identifiers::{RoomId, UserId}; +use crate::identifiers::RoomId; pub struct Account { inner: OlmAccount, @@ -207,7 +212,6 @@ impl PartialEq for Session { } } -#[derive(Debug)] pub struct InboundGroupSession { inner: OlmInboundGroupSession, pub(crate) sender_key: String, @@ -266,6 +270,78 @@ impl InboundGroupSession { } } +impl fmt::Debug for InboundGroupSession { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("InbounDGroupSession") + .field("session_id", &self.session_id()) + .finish() + } +} + +#[derive(Clone)] +pub struct OutboundGroupSession { + inner: Arc>, + session_id: Arc, + room_id: Arc, + creation_time: Arc, + message_count: Arc, + shared: Arc, +} + +impl OutboundGroupSession { + pub fn new(room_id: &RoomId) -> Self { + let session = OlmOutboundGroupSession::new(); + let session_id = session.session_id(); + OutboundGroupSession { + inner: Arc::new(Mutex::new(session)), + room_id: Arc::new(room_id.to_owned()), + session_id: Arc::new(session_id), + creation_time: Arc::new(Instant::now()), + message_count: Arc::new(AtomicUsize::new(0)), + shared: Arc::new(AtomicBool::new(false)), + } + } + + pub async fn encrypt(&self, plaintext: String) -> String { + let session = self.inner.lock().await; + session.encrypt(plaintext) + } + + pub fn expired(&self) -> bool { + // TODO implement this. + false + } + + pub fn mark_as_shared(&self) { + self.shared.store(true, Ordering::Relaxed); + } + + pub async fn session_key(&self) -> String { + let session = self.inner.lock().await; + session.session_key() + } + + pub fn session_id(&self) -> &str { + &self.session_id + } + + pub async fn message_index(&self) -> u32 { + let session = self.inner.lock().await; + session.session_message_index() + } +} + +impl std::fmt::Debug for OutboundGroupSession { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OutboundGroupSession") + .field("session_id", &self.session_id) + .field("room_id", &self.room_id) + .field("creation_time", &self.creation_time) + .field("message_count", &self.message_count) + .finish() + } +} + #[cfg(test)] mod test { use crate::crypto::olm::Account;