crypto: Add initial code for olm message decryption.

This commit is contained in:
Damir Jelić 2020-03-21 16:41:48 +01:00
parent 70dda32949
commit 4215f98e91
5 changed files with 179 additions and 10 deletions

View file

@ -25,6 +25,10 @@ pub enum OlmError {
Signature(#[from] SignatureError),
#[error("failed to read or write to the crypto store {0}")]
Store(#[from] CryptoStoreError),
#[error("decryption failed likely because a Olm session was wedged")]
SessionWedged,
#[error("the Olm message has a unsupported type")]
UnsupportedOlmType,
}
pub type VerificationResult<T> = std::result::Result<T, SignatureError>;

View file

@ -14,11 +14,12 @@
use std::collections::HashMap;
use std::convert::TryInto;
#[cfg(feature = "sqlite-cryptostore")]
use std::path::Path;
use std::result::Result as StdResult;
use std::sync::Arc;
use super::error::{Result, SignatureError, VerificationResult};
use super::error::{OlmError, Result, SignatureError, VerificationResult};
use super::olm::Account;
#[cfg(feature = "sqlite-cryptostore")]
use super::store::sqlite::SqliteStore;
@ -29,17 +30,19 @@ use crate::api;
use api::r0::keys;
use cjson;
use olm_rs::session::OlmMessage;
use olm_rs::utility::OlmUtility;
use serde_json::json;
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::{debug, info, instrument, warn};
use tracing::{debug, info, instrument, trace, warn};
use ruma_client_api::r0::keys::{
AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey,
};
use ruma_client_api::r0::sync::sync_events::IncomingResponse as SyncResponse;
use ruma_events::{
room::encrypted::EncryptedEventContent,
to_device::{AnyToDeviceEvent as ToDeviceEvent, ToDeviceEncrypted, ToDeviceRoomKeyRequest},
Algorithm, EventResult,
};
@ -383,6 +386,71 @@ impl OlmMachine {
Ok((device_keys, one_time_keys))
}
async fn try_decrypt_olm_event(
&mut self,
sender_key: &str,
message: &OlmMessage,
) -> Result<Option<String>> {
let mut s = self.store.sessions_mut(sender_key).await?;
let sessions = if let Some(s) = s {
s
} else {
return Ok(None);
};
for session in sessions.iter_mut() {
let mut matches = false;
if let OlmMessage::PreKey(m) = &message {
matches = session.matches(sender_key, m.clone()).unwrap();
if !matches {
continue;
}
}
let ret = session.decrypt(message.clone());
if let Ok(p) = ret {
// TODO save the session.
return Ok(Some(p));
} else {
if matches {
return Err(OlmError::SessionWedged);
}
}
}
Ok(None)
}
async fn decrypt_olm_message(
&mut self,
sender: &str,
sender_key: &str,
message: OlmMessage,
) -> Result<Option<ToDeviceEvent>> {
let plaintext = if let Some(p) = self.try_decrypt_olm_event(sender_key, &message).await? {
p
} else {
let mut session = match &message {
OlmMessage::Message(_) => return Err(OlmError::SessionWedged),
OlmMessage::PreKey(m) => {
let account = self.account.lock().await;
account
.create_inbound_session_from(sender_key, m.clone())
.unwrap()
}
};
session.decrypt(message).unwrap()
// TODO save the session
};
// TODO convert the plaintext to a ruma event.
todo!()
}
/// Decrypt a to-device event.
///
/// Returns a decrypted `ToDeviceEvent` if the decryption was successful,
@ -392,9 +460,31 @@ impl OlmMachine {
///
/// * `event` - The to-device event that should be decrypted.
#[instrument]
fn decrypt_to_device_event(&self, _: &ToDeviceEncrypted) -> StdResult<ToDeviceEvent, ()> {
async fn decrypt_to_device_event(
&self,
event: &ToDeviceEncrypted,
) -> Result<Option<ToDeviceEvent>> {
info!("Decrypting to-device event");
Err(())
let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content {
c
} else {
warn!("Error, unsupported encryption algorithm");
return Ok(None);
};
let identity_keys = self.account.lock().await.identity_keys();
let own_key = identity_keys.curve25519();
let own_ciphertext = content.ciphertext.get(own_key);
if let Some(ciphertext) = own_ciphertext {
let message_type: u8 = ciphertext.message_type.try_into().unwrap();
let message =
OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone())
.map_err(|_| OlmError::UnsupportedOlmType);
}
todo!()
}
fn handle_room_key_request(&self, _: &ToDeviceRoomKeyRequest) {
@ -405,7 +495,7 @@ impl OlmMachine {
// TODO handle to-device verification events here.
}
#[instrument]
#[instrument(skip(response))]
pub fn receive_sync_response(&mut self, response: &mut SyncResponse) {
let one_time_key_count = response
.device_one_time_keys_count

View file

@ -13,9 +13,11 @@
// limitations under the License.
use std::fmt;
use std::time::Instant;
use olm_rs::account::{IdentityKeys, OlmAccount, OneTimeKeys};
use olm_rs::errors::OlmAccountError;
use olm_rs::errors::{OlmAccountError, OlmSessionError};
use olm_rs::session::{OlmMessage, OlmSession, PreKeyMessage};
use olm_rs::PicklingMode;
pub struct Account {
@ -39,6 +41,7 @@ impl fmt::Debug for Account {
/// any synchronization. We're wrapping the whole Olm machine inside a Mutex to
/// get Sync for it
unsafe impl Send for Account {}
unsafe impl Send for Session {}
impl Account {
/// Create a new account.
@ -97,6 +100,24 @@ impl Account {
let acc = OlmAccount::unpickle(pickle, pickling_mode)?;
Ok(Account { inner: acc, shared })
}
pub fn create_inbound_session_from(
&self,
their_identity_key: &str,
message: PreKeyMessage,
) -> Result<Session, OlmSessionError> {
let session = self
.inner
.create_inbound_session_from(their_identity_key, message)?;
let now = Instant::now();
Ok(Session {
inner: session,
creation_time: now.clone(),
last_use_time: now,
})
}
}
impl PartialEq for Account {
@ -105,6 +126,29 @@ impl PartialEq for Account {
}
}
pub struct Session {
inner: OlmSession,
creation_time: Instant,
last_use_time: Instant,
}
impl Session {
pub fn decrypt(&mut self, message: OlmMessage) -> Result<String, OlmSessionError> {
let plaintext = self.inner.decrypt(message)?;
self.last_use_time = Instant::now();
Ok(plaintext)
}
pub fn matches(
&self,
their_identity_key: &str,
message: PreKeyMessage,
) -> Result<bool, OlmSessionError> {
self.inner
.matches_inbound_session_from(their_identity_key, message)
}
}
#[cfg(test)]
mod test {
use crate::crypto::olm::Account;

View file

@ -13,7 +13,9 @@
// limitations under the License.
use core::fmt::Debug;
use std::collections::HashMap;
use std::io::Error as IoError;
use std::result::Result as StdResult;
use std::sync::Arc;
use url::ParseError;
@ -21,7 +23,7 @@ use async_trait::async_trait;
use thiserror::Error;
use tokio::sync::Mutex;
use super::olm::Account;
use super::olm::{Account, Session};
use olm_rs::errors::OlmAccountError;
use olm_rs::PicklingMode;
@ -52,17 +54,21 @@ pub type Result<T> = std::result::Result<T, CryptoStoreError>;
pub trait CryptoStore: Debug {
async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
async fn sessions_mut(&mut self, sender_key: &str) -> Result<Option<&mut Vec<Session>>>;
}
#[derive(Debug)]
pub struct MemoryStore {
pub(crate) account_info: Option<(String, bool)>,
sessions: HashMap<String, Vec<Session>>,
}
impl MemoryStore {
/// Create a new empty memory store.
pub fn new() -> Self {
MemoryStore { account_info: None }
MemoryStore {
account_info: None,
sessions: HashMap::new(),
}
}
}
@ -86,4 +92,22 @@ impl CryptoStore for MemoryStore {
self.account_info = Some((pickle, acc.shared));
Ok(())
}
async fn sessions_mut<'a>(
&'a mut self,
sender_key: &str,
) -> Result<Option<&'a mut Vec<Session>>> {
Ok(self.sessions.get_mut(sender_key))
}
}
impl std::fmt::Debug for MemoryStore {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> {
write!(
fmt,
"MemoryStore {{ account_stored: {}, account shared: {} }}",
self.account_info.is_some(),
self.account_info.as_ref().map_or(false, |a| a.1)
)
}
}

View file

@ -23,7 +23,7 @@ use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConn
use tokio::sync::Mutex;
use zeroize::Zeroizing;
use super::{Account, CryptoStore, Result};
use super::{Account, CryptoStore, Result, Session};
pub struct SqliteStore {
user_id: Arc<String>,
@ -168,6 +168,13 @@ impl CryptoStore for SqliteStore {
Ok(())
}
async fn sessions_mut<'a>(
&'a mut self,
sender_key: &str,
) -> Result<Option<&'a mut Vec<Session>>> {
todo!()
}
}
impl std::fmt::Debug for SqliteStore {