crypto: Use a Read implementation for the attachment encryption as well.

master
Damir Jelić 2020-09-14 20:06:44 +02:00
parent 51f3d90224
commit 2d6882c495
3 changed files with 43 additions and 48 deletions

View File

@ -12,10 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{ use std::{collections::BTreeMap, io::Read};
collections::BTreeMap,
io::{Read, Write},
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -24,7 +21,7 @@ use matrix_sdk_common::events::room::JsonWebKey;
use getrandom::getrandom; use getrandom::getrandom;
use aes_ctr::{ use aes_ctr::{
stream_cipher::{NewStreamCipher, SyncStreamCipher, SyncStreamCipherSeek}, stream_cipher::{NewStreamCipher, SyncStreamCipher},
Aes256Ctr, Aes256Ctr,
}; };
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
@ -35,6 +32,7 @@ const IV_SIZE: usize = 16;
const KEY_SIZE: usize = 32; const KEY_SIZE: usize = 32;
const VERSION: u8 = 1; const VERSION: u8 = 1;
#[allow(missing_docs)]
pub struct AttachmentDecryptor<'a, R: 'a + Read> { pub struct AttachmentDecryptor<'a, R: 'a + Read> {
inner_reader: &'a mut R, inner_reader: &'a mut R,
expected_hash: Vec<u8>, expected_hash: Vec<u8>,
@ -47,7 +45,8 @@ impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> {
let read_bytes = self.inner_reader.read(buf)?; let read_bytes = self.inner_reader.read(buf)?;
if read_bytes == 0 { if read_bytes == 0 {
if self.sha.finalize_reset().as_slice() == self.expected_hash.as_slice() { let hash = self.sha.finalize_reset();
if hash.as_slice() == self.expected_hash.as_slice() {
Ok(0) Ok(0)
} else { } else {
panic!("INVALID HASH"); panic!("INVALID HASH");
@ -62,6 +61,7 @@ impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> {
} }
impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> { impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
#[allow(missing_docs)]
fn new(input: &'a mut R, info: EncryptionInfo) -> AttachmentDecryptor<'a, R> { fn new(input: &'a mut R, info: EncryptionInfo) -> AttachmentDecryptor<'a, R> {
// TODO check the version // TODO check the version
let hash = decode(info.hashes.get("sha256").unwrap()).unwrap(); let hash = decode(info.hashes.get("sha256").unwrap()).unwrap();
@ -81,9 +81,11 @@ impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
} }
} }
pub struct AttachmentEncryptor<'a, W: Write + 'a> { #[allow(missing_docs)]
#[derive(Debug)]
pub struct AttachmentEncryptor<'a, R: Read + 'a> {
finished: bool, finished: bool,
inner_writer: &'a mut W, inner_reader: &'a mut R,
web_key: JsonWebKey, web_key: JsonWebKey,
iv: String, iv: String,
hashes: BTreeMap<String, String>, hashes: BTreeMap<String, String>,
@ -91,36 +93,28 @@ pub struct AttachmentEncryptor<'a, W: Write + 'a> {
sha: Sha256, sha: Sha256,
} }
impl<'a, W: Write> Write for AttachmentEncryptor<'a, W> { impl<'a, R: Read + 'a> Read for AttachmentEncryptor<'a, R> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if buf.is_empty() { let read_bytes = self.inner_reader.read(buf)?;
return Ok(0);
if read_bytes == 0 {
let hash = self.sha.finalize_reset();
self.hashes
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
Ok(0)
} else {
self.aes.apply_keystream(&mut buf[0..read_bytes]);
self.sha.update(&buf[0..read_bytes]);
Ok(read_bytes)
} }
// TODO avoid this allocation.
let mut buffer = buf.to_owned();
self.aes.apply_keystream(&mut buffer);
let written = self.inner_writer.write(&buffer)?;
self.sha.update(&buffer[0..written]);
// If we have written less than what was decrypted, seek/move our AES
// counter back.
let offset = buffer.len() - written;
let mut pos = self.aes.current_pos();
pos -= offset as u64;
self.aes.seek(pos);
Ok(written)
}
fn flush(&mut self) -> std::io::Result<()> {
self.inner_writer.flush()
} }
} }
impl<'a, W: Write + 'a> AttachmentEncryptor<'a, W> { impl<'a, R: Read + 'a> AttachmentEncryptor<'a, R> {
pub fn new(writer: &'a mut W) -> Self { #[allow(missing_docs)]
pub fn new(reader: &'a mut R) -> Self {
// TODO Use zeroizing here. // TODO Use zeroizing here.
let mut key = [0u8; KEY_SIZE]; let mut key = [0u8; KEY_SIZE];
let mut iv = [0u8; IV_SIZE]; let mut iv = [0u8; IV_SIZE];
@ -143,7 +137,7 @@ impl<'a, W: Write + 'a> AttachmentEncryptor<'a, W> {
AttachmentEncryptor { AttachmentEncryptor {
finished: false, finished: false,
inner_writer: writer, inner_reader: reader,
iv: encoded_iv, iv: encoded_iv,
web_key, web_key,
hashes: BTreeMap::new(), hashes: BTreeMap::new(),
@ -152,9 +146,12 @@ impl<'a, W: Write + 'a> AttachmentEncryptor<'a, W> {
} }
} }
#[allow(missing_docs)]
pub fn finish(mut self) -> EncryptionInfo { pub fn finish(mut self) -> EncryptionInfo {
let hash = self.sha.finalize(); let hash = self.sha.finalize();
self.hashes.insert("sha256".to_owned(), encode(hash)); self.hashes
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
EncryptionInfo { EncryptionInfo {
version: "v2".to_string(), version: "v2".to_string(),
@ -168,10 +165,10 @@ impl<'a, W: Write + 'a> AttachmentEncryptor<'a, W> {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct EncryptionInfo { pub struct EncryptionInfo {
#[serde(rename = "v")] #[serde(rename = "v")]
version: String, pub version: String,
web_key: JsonWebKey, pub web_key: JsonWebKey,
iv: String, pub iv: String,
hashes: BTreeMap<String, String>, pub hashes: BTreeMap<String, String>,
} }
#[cfg(test)] #[cfg(test)]
@ -207,20 +204,17 @@ mod test {
#[test] #[test]
fn encrypt_decrypt_cycle() { fn encrypt_decrypt_cycle() {
let data = "Hello world".to_owned(); let data = "Hello world".to_owned();
let mut cursor = Cursor::new(Vec::with_capacity(data.len())); let mut cursor = Cursor::new(data.clone());
let mut encryptor = AttachmentEncryptor::new(&mut cursor); let mut encryptor = AttachmentEncryptor::new(&mut cursor);
encryptor.write_all(data.as_bytes()).unwrap();
let key = encryptor.finish();
cursor.set_position(0);
let mut encrypted = Vec::new(); let mut encrypted = Vec::new();
cursor.read_to_end(&mut encrypted).unwrap();
cursor.set_position(0);
encryptor.read_to_end(&mut encrypted).unwrap();
let key = encryptor.finish();
assert_ne!(encrypted.as_slice(), data.as_bytes()); assert_ne!(encrypted.as_slice(), data.as_bytes());
let mut cursor = Cursor::new(encrypted);
let mut decryptor = AttachmentDecryptor::new(&mut cursor, key); let mut decryptor = AttachmentDecryptor::new(&mut cursor, key);
let mut decrypted_data = Vec::new(); let mut decrypted_data = Vec::new();

View File

@ -2,6 +2,7 @@
mod attachments; mod attachments;
mod key_export; mod key_export;
pub use attachments::AttachmentEncryptor;
pub use key_export::{decrypt_key_export, encrypt_key_export}; pub use key_export::{decrypt_key_export, encrypt_key_export};
use base64::{decode_config, encode_config, DecodeError, STANDARD_NO_PAD, URL_SAFE_NO_PAD}; use base64::{decode_config, encode_config, DecodeError, STANDARD_NO_PAD, URL_SAFE_NO_PAD};

View File

@ -37,7 +37,7 @@ pub mod store;
mod verification; mod verification;
pub use error::{MegolmError, OlmError}; pub use error::{MegolmError, OlmError};
pub use file_encryption::{decrypt_key_export, encrypt_key_export}; pub use file_encryption::{decrypt_key_export, encrypt_key_export, AttachmentEncryptor};
pub use identities::{ pub use identities::{
Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserDevices, UserIdentities, UserIdentity, Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserDevices, UserIdentities, UserIdentity,
}; };