Merge branch 'master' into new-state-store

master
Damir Jelić 2020-12-01 17:24:00 +01:00
commit 45442dfac8
13 changed files with 418 additions and 173 deletions

View File

@ -16,7 +16,7 @@
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use std::{collections::BTreeMap, io::Write, path::PathBuf}; use std::{collections::BTreeMap, io::Write, path::PathBuf};
use std::{ use std::{
convert::{TryFrom, TryInto}, convert::TryInto,
fmt::{self, Debug}, fmt::{self, Debug},
future::Future, future::Future,
io::Read, io::Read,
@ -803,7 +803,7 @@ impl Client {
since: Option<&str>, since: Option<&str>,
server: Option<&ServerName>, server: Option<&ServerName>,
) -> Result<get_public_rooms::Response> { ) -> Result<get_public_rooms::Response> {
let limit = limit.map(|n| UInt::try_from(n).ok()).flatten(); let limit = limit.map(|n| UInt::from(n));
let request = assign!(get_public_rooms::Request::new(), { let request = assign!(get_public_rooms::Request::new(), {
limit, limit,

View File

@ -47,9 +47,16 @@ pub enum OlmError {
Store(#[from] CryptoStoreError), Store(#[from] CryptoStoreError),
/// The session with a device has become corrupted. /// The session with a device has become corrupted.
#[error("decryption failed likely because an Olm from {0} with sender key {1} was wedged")] #[error(
"decryption failed likely because an Olm session from {0} with sender key {1} was wedged"
)]
SessionWedged(UserId, String), SessionWedged(UserId, String),
/// An Olm message got replayed while the Olm ratchet has already moved
/// forward.
#[error("decryption failed because an Olm message from {0} with sender key {1} was replayed")]
ReplayedMessage(UserId, String),
/// Encryption failed because the device does not have a valid Olm session /// Encryption failed because the device does not have a valid Olm session
/// with us. /// with us.
#[error( #[error(

View File

@ -14,10 +14,10 @@ fn decode_url_safe(input: impl AsRef<[u8]>) -> Result<Vec<u8>, DecodeError> {
decode_config(input, URL_SAFE_NO_PAD) decode_config(input, URL_SAFE_NO_PAD)
} }
fn encode(input: impl AsRef<[u8]>) -> String { pub fn encode(input: impl AsRef<[u8]>) -> String {
encode_config(input, STANDARD_NO_PAD) encode_config(input, STANDARD_NO_PAD)
} }
fn encode_url_safe(input: impl AsRef<[u8]>) -> String { pub fn encode_url_safe(input: impl AsRef<[u8]>) -> String {
encode_config(input, URL_SAFE_NO_PAD) encode_config(input, URL_SAFE_NO_PAD)
} }

View File

@ -1137,12 +1137,11 @@ mod test {
.unwrap() .unwrap()
.is_none()); .is_none());
let (_, decrypted, sender_key, _) = let decrypted = alice_account.decrypt_to_device_event(&event).await.unwrap();
alice_account.decrypt_to_device_event(&event).await.unwrap();
if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.event.deserialize().unwrap() {
let (_, session) = alice_machine let (_, session) = alice_machine
.receive_forwarded_room_key(&sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine alice_machine
@ -1157,7 +1156,11 @@ mod test {
// Check that alice now does have the session. // Check that alice now does have the session.
let session = alice_machine let session = alice_machine
.store .store
.get_inbound_group_session(&room_id(), &sender_key, group_session.session_id()) .get_inbound_group_session(
&room_id(),
&decrypted.sender_key,
group_session.session_id(),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
@ -1325,12 +1328,11 @@ mod test {
.unwrap() .unwrap()
.is_none()); .is_none());
let (_, decrypted, sender_key, _) = let decrypted = alice_account.decrypt_to_device_event(&event).await.unwrap();
alice_account.decrypt_to_device_event(&event).await.unwrap();
if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.event.deserialize().unwrap() {
let (_, session) = alice_machine let (_, session) = alice_machine
.receive_forwarded_room_key(&sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine alice_machine
@ -1345,7 +1347,11 @@ mod test {
// Check that alice now does have the session. // Check that alice now does have the session.
let session = alice_machine let session = alice_machine
.store .store
.get_inbound_group_session(&room_id(), &sender_key, group_session.session_id()) .get_inbound_group_session(
&room_id(),
&decrypted.sender_key,
group_session.session_id(),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();

View File

@ -52,7 +52,8 @@ use crate::{
key_request::KeyRequestMachine, key_request::KeyRequestMachine,
olm::{ olm::{
Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys,
InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, ReadOnlyAccount,
SessionType,
}, },
requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest},
session_manager::{GroupSessionManager, SessionManager}, session_manager::{GroupSessionManager, SessionManager},
@ -365,10 +366,15 @@ impl OlmMachine {
/// Mark the cross signing identity as shared. /// Mark the cross signing identity as shared.
async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> { async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> {
self.user_identity.lock().await.mark_as_shared(); let identity = self.user_identity.lock().await;
self.store identity.mark_as_shared();
.save_identity((&*self.user_identity.lock().await).clone())
.await let changes = Changes {
private_identity: Some(identity.clone()),
..Default::default()
};
self.store.save_changes(changes).await
} }
/// Create a new cross signing identity and get the upload request to push /// Create a new cross signing identity and get the upload request to push
@ -400,11 +406,12 @@ impl OlmMachine {
new: vec![public.into()], new: vec![public.into()],
..Default::default() ..Default::default()
}, },
private_identity: Some(identity.clone()),
..Default::default() ..Default::default()
}; };
self.store.save_changes(changes).await?; self.store.save_changes(changes).await?;
self.store.save_identity(identity.clone()).await?;
Ok((request, signature_request)) Ok((request, signature_request))
} else { } else {
info!("Trying to upload the existing cross signing identity"); info!("Trying to upload the existing cross signing identity");
@ -555,24 +562,23 @@ impl OlmMachine {
async fn decrypt_to_device_event( async fn decrypt_to_device_event(
&self, &self,
event: &ToDeviceEvent<EncryptedEventContent>, event: &ToDeviceEvent<EncryptedEventContent>,
) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, Option<InboundGroupSession>)> { ) -> OlmResult<OlmDecryptionInfo> {
let (session, decrypted_event, sender_key, signing_key) = let mut decrypted = self.account.decrypt_to_device_event(event).await?;
self.account.decrypt_to_device_event(event).await?;
// Handle the decrypted event, e.g. fetch out Megolm sessions out of // Handle the decrypted event, e.g. fetch out Megolm sessions out of
// the event. // the event.
if let (Some(event), group_session) = self if let (Some(event), group_session) =
.handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event) self.handle_decrypted_to_device_event(&decrypted).await?
.await?
{ {
// Some events may have sensitive data e.g. private keys, while we // Some events may have sensitive data e.g. private keys, while we
// want to notify our users that a private key was received we // want to notify our users that a private key was received we
// don't want them to be able to do silly things with it. Handling // don't want them to be able to do silly things with it. Handling
// events modifies them and returns a modified one, so replace it // events modifies them and returns a modified one, so replace it
// here if we get one. // here if we get one.
Ok((session, event, group_session)) decrypted.event = event;
} else { decrypted.inbound_group_session = group_session;
Ok((session, decrypted_event, None))
} }
Ok(decrypted)
} }
/// Create a group session from a room key and add it to our crypto store. /// Create a group session from a room key and add it to our crypto store.
@ -704,27 +710,29 @@ impl OlmMachine {
/// * `event` - The decrypted to-device event. /// * `event` - The decrypted to-device event.
async fn handle_decrypted_to_device_event( async fn handle_decrypted_to_device_event(
&self, &self,
sender_key: &str, decrypted: &OlmDecryptionInfo,
signing_key: &str,
event: &Raw<AnyToDeviceEvent>,
) -> OlmResult<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>)> { ) -> OlmResult<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>)> {
let event = if let Ok(e) = event.deserialize() { let event = match decrypted.event.deserialize() {
e Ok(e) => e,
} else { Err(e) => {
warn!("Decrypted to-device event failed to be parsed correctly"); warn!(
return Ok((None, None)); "Decrypted to-device event failed to be parsed correctly {:?}",
e
);
return Ok((None, None));
}
}; };
match event { match event {
AnyToDeviceEvent::RoomKey(mut e) => { AnyToDeviceEvent::RoomKey(mut e) => Ok(self
Ok(self.add_room_key(sender_key, signing_key, &mut e).await?) .add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e)
} .await?),
AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self
.key_request_machine .key_request_machine
.receive_forwarded_room_key(sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await?), .await?),
_ => { _ => {
warn!("Received a unexpected encrypted to-device event"); warn!("Received an unexpected encrypted to-device event");
Ok((None, None)) Ok((None, None))
} }
} }
@ -808,38 +816,49 @@ impl OlmMachine {
match &mut event { match &mut event {
AnyToDeviceEvent::RoomEncrypted(e) => { AnyToDeviceEvent::RoomEncrypted(e) => {
let (session, decrypted_event, group_session) = let decrypted = match self.decrypt_to_device_event(e).await {
match self.decrypt_to_device_event(e).await { Ok(e) => e,
Ok(e) => e, Err(err) => {
Err(err) => { warn!(
warn!( "Failed to decrypt to-device event from {} {}",
"Failed to decrypt to-device event from {} {}", e.sender, err
e.sender, err );
);
if let OlmError::SessionWedged(sender, curve_key) = err { if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self if let Err(e) = self
.session_manager .session_manager
.mark_device_as_wedged(&sender, &curve_key) .mark_device_as_wedged(&sender, &curve_key)
.await .await
{ {
error!( error!(
"Couldn't mark device from {} to be unwedged {:?}", "Couldn't mark device from {} to be unwedged {:?}",
sender, e sender, e
); );
}
} }
continue;
} }
}; continue;
}
};
changes.sessions.push(session); // New sessions modify the account so we need to save that
// one as well.
match decrypted.session {
SessionType::New(s) => {
changes.sessions.push(s);
changes.account = Some(self.account.inner.clone());
}
SessionType::Existing(s) => {
changes.sessions.push(s);
}
}
if let Some(group_session) = group_session { changes.message_hashes.push(decrypted.message_hash);
if let Some(group_session) = decrypted.inbound_group_session {
changes.inbound_group_sessions.push(group_session); changes.inbound_group_sessions.push(group_session);
} }
*event_result = decrypted_event; *event_result = decrypted.event;
} }
AnyToDeviceEvent::RoomKeyRequest(e) => { AnyToDeviceEvent::RoomKeyRequest(e) => {
self.key_request_machine.receive_incoming_key_request(e) self.key_request_machine.receive_incoming_key_request(e)
@ -1283,8 +1302,11 @@ pub(crate) mod test {
content, content,
}; };
let (session, _, _) = bob.decrypt_to_device_event(&event).await.unwrap(); let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store.save_sessions(&[session]).await.unwrap(); bob.store
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
(alice, bob) (alice, bob)
} }
@ -1578,7 +1600,7 @@ pub(crate) mod test {
.decrypt_to_device_event(&event) .decrypt_to_device_event(&event)
.await .await
.unwrap() .unwrap()
.1 .event
.deserialize() .deserialize()
.unwrap(); .unwrap();
@ -1614,14 +1636,17 @@ pub(crate) mod test {
.get_outbound_group_session(&room_id) .get_outbound_group_session(&room_id)
.unwrap(); .unwrap();
let (session, event, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store.save_sessions(&[session]).await.unwrap();
bob.store bob.store
.save_inbound_group_sessions(&[group_session.unwrap()]) .save_sessions(&[decrypted.session.session()])
.await .await
.unwrap(); .unwrap();
let event = event.deserialize().unwrap(); bob.store
.save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()])
.await
.unwrap();
let event = decrypted.event.deserialize().unwrap();
if let AnyToDeviceEvent::RoomKey(event) = event { if let AnyToDeviceEvent::RoomKey(event) = event {
assert_eq!(&event.sender, alice.user_id()); assert_eq!(&event.sender, alice.user_id());
@ -1661,7 +1686,11 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests), content: to_device_requests_to_content(to_device_requests),
}; };
let (_, _, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); let group_session = bob
.decrypt_to_device_event(&event)
.await
.unwrap()
.inbound_group_session;
bob.store bob.store
.save_inbound_group_sessions(&[group_session.unwrap()]) .save_inbound_group_sessions(&[group_session.unwrap()])
.await .await

View File

@ -15,6 +15,7 @@
use matrix_sdk_common::events::ToDeviceEvent; use matrix_sdk_common::events::ToDeviceEvent;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
@ -53,9 +54,10 @@ use olm_rs::{
use crate::{ use crate::{
error::{EventError, OlmResult, SessionCreationError}, error::{EventError, OlmResult, SessionCreationError},
file_encryption::encode,
identities::ReadOnlyDevice, identities::ReadOnlyDevice,
requests::UploadSigningKeysRequest, requests::UploadSigningKeysRequest,
store::Store, store::{Changes, Store},
OlmError, OlmError,
}; };
@ -70,6 +72,43 @@ pub struct Account {
pub(crate) store: Store, pub(crate) store: Store,
} }
#[derive(Debug, Clone)]
pub enum SessionType {
New(Session),
Existing(Session),
}
impl SessionType {
#[cfg(test)]
pub fn session(self) -> Session {
match self {
SessionType::New(s) => s,
SessionType::Existing(s) => s,
}
}
}
#[derive(Debug, Clone)]
pub struct OlmDecryptionInfo {
pub session: SessionType,
pub message_hash: OlmMessageHash,
pub event: Raw<AnyToDeviceEvent>,
pub signing_key: String,
pub sender_key: String,
pub inbound_group_session: Option<InboundGroupSession>,
}
/// A hash of a succesfully decrypted Olm message.
///
/// Can be used to check if a message has been replayed to us.
#[derive(Debug, Clone)]
pub struct OlmMessageHash {
/// The curve25519 key of the sender that sent us the Olm message.
pub sender_key: String,
/// The hash of the message.
pub hash: String,
}
impl Deref for Account { impl Deref for Account {
type Target = ReadOnlyAccount; type Target = ReadOnlyAccount;
@ -82,7 +121,7 @@ impl Account {
pub async fn decrypt_to_device_event( pub async fn decrypt_to_device_event(
&self, &self,
event: &ToDeviceEvent<EncryptedEventContent>, event: &ToDeviceEvent<EncryptedEventContent>,
) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, String, String)> { ) -> OlmResult<OlmDecryptionInfo> {
debug!("Decrypting to-device event"); debug!("Decrypting to-device event");
let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content { let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content {
@ -103,23 +142,47 @@ impl Account {
.try_into() .try_into()
.map_err(|_| EventError::UnsupportedOlmType)?; .map_err(|_| EventError::UnsupportedOlmType)?;
let sha = Sha256::new()
.chain(&content.sender_key)
.chain(&[message_type])
.chain(&ciphertext.body);
let message_hash = OlmMessageHash {
sender_key: content.sender_key.clone(),
hash: encode(sha.finalize().as_slice()),
};
// Create a OlmMessage from the ciphertext and the type. // Create a OlmMessage from the ciphertext and the type.
let message = let message =
OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone()) OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone())
.map_err(|_| EventError::UnsupportedOlmType)?; .map_err(|_| EventError::UnsupportedOlmType)?;
// Decrypt the OlmMessage and get a Ruma event out of it. // Decrypt the OlmMessage and get a Ruma event out of it.
let (session, decrypted_event, signing_key) = self let (session, event, signing_key) = match self
.decrypt_olm_message(&event.sender, &content.sender_key, message) .decrypt_olm_message(&event.sender, &content.sender_key, message)
.await?; .await
{
Ok(d) => d,
Err(OlmError::SessionWedged(user_id, sender_key)) => {
if self.store.is_message_known(&message_hash).await? {
return Err(OlmError::ReplayedMessage(user_id, sender_key));
} else {
return Err(OlmError::SessionWedged(user_id, sender_key));
}
}
Err(e) => return Err(e.into()),
};
debug!("Decrypted a to-device event {:?}", decrypted_event); debug!("Decrypted a to-device event {:?}", event);
Ok((
Ok(OlmDecryptionInfo {
session, session,
decrypted_event, message_hash,
content.sender_key.clone(), event,
signing_key, signing_key,
)) sender_key: content.sender_key.clone(),
inbound_group_session: None,
})
} else { } else {
warn!("Olm event doesn't contain a ciphertext for our key"); warn!("Olm event doesn't contain a ciphertext for our key");
Err(EventError::MissingCiphertext.into()) Err(EventError::MissingCiphertext.into())
@ -227,7 +290,7 @@ impl Account {
sender: &UserId, sender: &UserId,
sender_key: &str, sender_key: &str,
message: OlmMessage, message: OlmMessage,
) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, String)> { ) -> OlmResult<(SessionType, Raw<AnyToDeviceEvent>, String)> {
// First try to decrypt using an existing session. // First try to decrypt using an existing session.
let (session, plaintext) = if let Some(d) = self let (session, plaintext) = if let Some(d) = self
.try_decrypt_olm_message(sender, sender_key, &message) .try_decrypt_olm_message(sender, sender_key, &message)
@ -235,7 +298,7 @@ impl Account {
{ {
// Decryption succeeded, de-structure the session/plaintext out of // Decryption succeeded, de-structure the session/plaintext out of
// the Option. // the Option.
d (SessionType::Existing(d.0), d.1)
} else { } else {
// Decryption failed with every known session, let's try to create a // Decryption failed with every known session, let's try to create a
// new session. // new session.
@ -282,7 +345,7 @@ impl Account {
// Decrypt our message, this shouldn't fail since we're using a // Decrypt our message, this shouldn't fail since we're using a
// newly created Session. // newly created Session.
let plaintext = session.decrypt(message).await?; let plaintext = session.decrypt(message).await?;
(session, plaintext) (SessionType::New(session), plaintext)
}; };
trace!("Successfully decrypted a Olm message: {}", plaintext); trace!("Successfully decrypted a Olm message: {}", plaintext);
@ -293,7 +356,20 @@ impl Account {
// We might created a new session but decryption might still // We might created a new session but decryption might still
// have failed, store it for the error case here, this is fine // have failed, store it for the error case here, this is fine
// since we don't expect this to happen often or at all. // since we don't expect this to happen often or at all.
self.store.save_sessions(&[session]).await?; match session {
SessionType::New(s) => {
let changes = Changes {
account: Some(self.inner.clone()),
sessions: vec![s],
..Default::default()
};
self.store.save_changes(changes).await?;
}
SessionType::Existing(s) => {
self.store.save_sessions(&[s]).await?;
}
}
return Err(e); return Err(e);
} }
}; };

View File

@ -23,8 +23,8 @@ mod session;
mod signing; mod signing;
mod utility; mod utility;
pub(crate) use account::Account; pub(crate) use account::{Account, OlmDecryptionInfo, SessionType};
pub use account::{AccountPickle, PickledAccount, ReadOnlyAccount}; pub use account::{AccountPickle, OlmMessageHash, PickledAccount, ReadOnlyAccount};
pub use group_sessions::{ pub use group_sessions::{
EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle, EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle,
PickledInboundGroupSession, PickledInboundGroupSession,

View File

@ -126,9 +126,7 @@ impl Session {
"content": content, "content": content,
}); });
let plaintext = serde_json::to_string(&payload) let plaintext = serde_json::to_string(&payload)?;
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload)));
let ciphertext = self.encrypt_helper(&plaintext).await.to_tuple(); let ciphertext = self.encrypt_helper(&plaintext).await.to_tuple();
let message_type = ciphertext.0; let message_type = ciphertext.0;

View File

@ -170,7 +170,6 @@ impl MasterSigning {
} }
pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) { pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
// TODO create a borrowed version of a cross singing key.
let subkey_wihtout_signatures = json!({ let subkey_wihtout_signatures = json!({
"user_id": subkey.user_id.clone(), "user_id": subkey.user_id.clone(),
"keys": subkey.keys.clone(), "keys": subkey.keys.clone(),

View File

@ -40,6 +40,7 @@ pub struct MemoryStore {
inbound_group_sessions: GroupSessionStore, inbound_group_sessions: GroupSessionStore,
tracked_users: Arc<DashSet<UserId>>, tracked_users: Arc<DashSet<UserId>>,
users_for_key_query: Arc<DashSet<UserId>>, users_for_key_query: Arc<DashSet<UserId>>,
olm_hashes: Arc<DashMap<String, DashSet<String>>>,
devices: DeviceStore, devices: DeviceStore,
identities: Arc<DashMap<UserId, UserIdentities>>, identities: Arc<DashMap<UserId, UserIdentities>>,
values: Arc<DashMap<String, String>>, values: Arc<DashMap<String, String>>,
@ -52,6 +53,7 @@ impl Default for MemoryStore {
inbound_group_sessions: GroupSessionStore::new(), inbound_group_sessions: GroupSessionStore::new(),
tracked_users: Arc::new(DashSet::new()), tracked_users: Arc::new(DashSet::new()),
users_for_key_query: Arc::new(DashSet::new()), users_for_key_query: Arc::new(DashSet::new()),
olm_hashes: Arc::new(DashMap::new()),
devices: DeviceStore::new(), devices: DeviceStore::new(),
identities: Arc::new(DashMap::new()), identities: Arc::new(DashMap::new()),
values: Arc::new(DashMap::new()), values: Arc::new(DashMap::new()),
@ -120,6 +122,13 @@ impl CryptoStore for MemoryStore {
.insert(identity.user_id().to_owned(), identity.clone()); .insert(identity.user_id().to_owned(), identity.clone());
} }
for hash in changes.message_hashes {
self.olm_hashes
.entry(hash.sender_key.to_owned())
.or_insert_with(DashSet::new)
.insert(hash.hash.clone());
}
Ok(()) Ok(())
} }
@ -211,21 +220,25 @@ impl CryptoStore for MemoryStore {
Ok(self.values.get(key).map(|v| v.to_owned())) Ok(self.values.get(key).map(|v| v.to_owned()))
} }
async fn save_identity(&self, _: PrivateCrossSigningIdentity) -> Result<()> {
Ok(())
}
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> { async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
Ok(None) Ok(None)
} }
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
Ok(self
.olm_hashes
.entry(message_hash.sender_key.to_owned())
.or_insert_with(DashSet::new)
.contains(&message_hash.hash))
}
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{ use crate::{
identities::device::test::get_device, identities::device::test::get_device,
olm::{test::get_account_and_session, InboundGroupSession}, olm::{test::get_account_and_session, InboundGroupSession, OlmMessageHash},
store::{memorystore::MemoryStore, CryptoStore}, store::{memorystore::MemoryStore, Changes, CryptoStore},
}; };
use matrix_sdk_common::identifiers::room_id; use matrix_sdk_common::identifiers::room_id;
@ -329,4 +342,21 @@ mod test {
assert!(store.is_user_tracked(device.user_id())); assert!(store.is_user_tracked(device.user_id()));
} }
#[tokio::test]
async fn test_message_hash() {
let store = MemoryStore::new();
let hash = OlmMessageHash {
sender_key: "test_sender".to_owned(),
hash: "test_hash".to_owned(),
};
let mut changes = Changes::default();
changes.message_hashes.push(hash.clone());
assert!(!store.is_message_known(&hash).await.unwrap());
store.save_changes(changes).await.unwrap();
assert!(store.is_message_known(&hash).await.unwrap());
}
} }

View File

@ -82,7 +82,9 @@ use matrix_sdk_common_macros::send_sync;
use crate::{ use crate::{
error::SessionUnpicklingError, error::SessionUnpicklingError,
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities},
olm::{InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session}, olm::{
InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, ReadOnlyAccount, Session,
},
verification::VerificationMachine, verification::VerificationMachine,
}; };
@ -107,7 +109,9 @@ pub(crate) struct Store {
#[allow(missing_docs)] #[allow(missing_docs)]
pub struct Changes { pub struct Changes {
pub account: Option<ReadOnlyAccount>, pub account: Option<ReadOnlyAccount>,
pub private_identity: Option<PrivateCrossSigningIdentity>,
pub sessions: Vec<Session>, pub sessions: Vec<Session>,
pub message_hashes: Vec<OlmMessageHash>,
pub inbound_group_sessions: Vec<InboundGroupSession>, pub inbound_group_sessions: Vec<InboundGroupSession>,
pub identities: IdentityChanges, pub identities: IdentityChanges,
pub devices: DeviceChanges, pub devices: DeviceChanges,
@ -342,13 +346,14 @@ pub trait CryptoStore: Debug {
/// * `account` - The account that should be stored. /// * `account` - The account that should be stored.
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>;
/// TODO /// Try to load a private cross signing identity, if one is stored.
async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()>;
/// TODO
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>>; async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>>;
/// TODO /// Save the set of changes to the store.
///
/// # Arguments
///
/// * `changes` - The set of changes that should be stored.
async fn save_changes(&self, changes: Changes) -> Result<()>; async fn save_changes(&self, changes: Changes) -> Result<()>;
/// Get all the sessions that belong to the given sender key. /// Get all the sessions that belong to the given sender key.
@ -435,4 +440,7 @@ pub trait CryptoStore: Debug {
/// Load a serializeable object from the store. /// Load a serializeable object from the store.
async fn get_value(&self, key: &str) -> Result<Option<String>>; async fn get_value(&self, key: &str) -> Result<Option<String>>;
/// Check if a hash for an Olm message stored in the database.
async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>;
} }

View File

@ -42,8 +42,9 @@ use crate::{
identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity},
olm::{ olm::{
AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle, AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle,
PickledAccount, PickledCrossSigningIdentity, PickledInboundGroupSession, PickledSession, OlmMessageHash, PickledAccount, PickledCrossSigningIdentity, PickledInboundGroupSession,
PicklingMode, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, SessionPickle, PickledSession, PicklingMode, PrivateCrossSigningIdentity, ReadOnlyAccount, Session,
SessionPickle,
}, },
}; };
@ -491,6 +492,24 @@ impl SqliteStore {
) )
.await?; .await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS olm_hashes (
"id" INTEGER NOT NULL PRIMARY KEY,
"account_id" INTEGER NOT NULL,
"sender_key" TEXT NOT NULL,
"hash" TEXT NOT NULL,
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
ON DELETE CASCADE
UNIQUE(account_id,sender_key,hash)
);
CREATE INDEX IF NOT EXISTS "olm_hashes_index" ON "olm_hashes" ("account_id");
"#,
)
.await?;
Ok(()) Ok(())
} }
@ -1466,6 +1485,92 @@ impl SqliteStore {
Ok(()) Ok(())
} }
async fn save_olm_hashses(
&self,
connection: &mut SqliteConnection,
hashes: &[OlmMessageHash],
) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for hash in hashes {
query("REPLACE INTO olm_hashes (account_id, sender_key, hash) VALUES (?1, ?2, ?3)")
.bind(account_id)
.bind(&hash.sender_key)
.bind(&hash.hash)
.execute(&mut *connection)
.await?;
}
Ok(())
}
async fn save_identity(
&self,
connection: &mut SqliteConnection,
identity: PrivateCrossSigningIdentity,
) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let pickle = identity.pickle(self.get_pickle_key()).await?;
query(
"INSERT INTO private_identities (
account_id, user_id, pickle, shared
) VALUES (?1, ?2, ?3, ?4)
ON CONFLICT(account_id, user_id) DO UPDATE SET
pickle = excluded.pickle,
shared = excluded.shared
",
)
.bind(account_id)
.bind(pickle.user_id.as_str())
.bind(pickle.pickle)
.bind(pickle.shared)
.execute(&mut *connection)
.await?;
Ok(())
}
async fn save_account_helper(
&self,
connection: &mut SqliteConnection,
account: ReadOnlyAccount,
) -> Result<()> {
let pickle = account.pickle(self.get_pickle_mode()).await;
query(
"INSERT INTO accounts (
user_id, device_id, pickle, shared, uploaded_key_count
) VALUES (?1, ?2, ?3, ?4, ?5)
ON CONFLICT(user_id, device_id) DO UPDATE SET
pickle = excluded.pickle,
shared = excluded.shared,
uploaded_key_count = excluded.uploaded_key_count
",
)
.bind(pickle.user_id.as_str())
.bind(pickle.device_id.as_str())
.bind(pickle.pickle.as_str())
.bind(pickle.shared)
.bind(pickle.uploaded_signed_key_count)
.execute(&mut *connection)
.await?;
let account_id: (i64,) =
query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?")
.bind(self.user_id.as_str())
.bind(self.device_id.as_str())
.fetch_one(&mut *connection)
.await?;
*self.account_info.lock().unwrap() = Some(AccountInfo {
account_id: account_id.0,
identity_keys: account.identity_keys.clone(),
});
Ok(())
}
async fn save_user_helper( async fn save_user_helper(
&self, &self,
mut connection: &mut SqliteConnection, mut connection: &mut SqliteConnection,
@ -1569,65 +1674,7 @@ impl CryptoStore for SqliteStore {
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> {
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let pickle = account.pickle(self.get_pickle_mode()).await; self.save_account_helper(&mut connection, account).await
query(
"INSERT INTO accounts (
user_id, device_id, pickle, shared, uploaded_key_count
) VALUES (?1, ?2, ?3, ?4, ?5)
ON CONFLICT(user_id, device_id) DO UPDATE SET
pickle = excluded.pickle,
shared = excluded.shared,
uploaded_key_count = excluded.uploaded_key_count
",
)
.bind(pickle.user_id.as_str())
.bind(pickle.device_id.as_str())
.bind(pickle.pickle.as_str())
.bind(pickle.shared)
.bind(pickle.uploaded_signed_key_count)
.execute(&mut *connection)
.await?;
let account_id: (i64,) =
query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?")
.bind(self.user_id.as_str())
.bind(self.device_id.as_str())
.fetch_one(&mut *connection)
.await?;
*self.account_info.lock().unwrap() = Some(AccountInfo {
account_id: account_id.0,
identity_keys: account.identity_keys.clone(),
});
Ok(())
}
async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let pickle = identity.pickle(self.get_pickle_key()).await?;
let mut connection = self.connection.lock().await;
query(
"INSERT INTO private_identities (
account_id, user_id, pickle, shared
) VALUES (?1, ?2, ?3, ?4)
ON CONFLICT(account_id, user_id) DO UPDATE SET
pickle = excluded.pickle,
shared = excluded.shared
",
)
.bind(account_id)
.bind(pickle.user_id.as_str())
.bind(pickle.pickle)
.bind(pickle.shared)
.execute(&mut *connection)
.await?;
Ok(())
} }
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> { async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
@ -1664,6 +1711,14 @@ impl CryptoStore for SqliteStore {
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?; let mut transaction = connection.begin().await?;
if let Some(account) = changes.account {
self.save_account_helper(&mut transaction, account).await?;
}
if let Some(identity) = changes.private_identity {
self.save_identity(&mut transaction, identity).await?;
}
self.save_sessions_helper(&mut transaction, &changes.sessions) self.save_sessions_helper(&mut transaction, &changes.sessions)
.await?; .await?;
self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions) self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions)
@ -1680,6 +1735,8 @@ impl CryptoStore for SqliteStore {
.await?; .await?;
self.save_user_identities(&mut transaction, &changes.identities.changed) self.save_user_identities(&mut transaction, &changes.identities.changed)
.await?; .await?;
self.save_olm_hashses(&mut transaction, &changes.message_hashes)
.await?;
transaction.commit().await?; transaction.commit().await?;
@ -1796,6 +1853,22 @@ impl CryptoStore for SqliteStore {
Ok(row.map(|r| r.0)) Ok(row.map(|r| r.0))
} }
async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let row: Option<(String,)> = query_as(
"SELECT hash FROM olm_hashes WHERE account_id = ? and sender_key = ? and hash = ?",
)
.bind(account_id)
.bind(&message_hash.sender_key)
.bind(&message_hash.hash)
.fetch_optional(&mut *connection)
.await?;
Ok(row.is_some())
}
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
@ -1817,8 +1890,8 @@ mod test {
user::test::{get_other_identity, get_own_identity}, user::test::{get_other_identity, get_own_identity},
}, },
olm::{ olm::{
GroupSessionKey, InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, GroupSessionKey, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity,
Session, ReadOnlyAccount, Session,
}, },
store::{Changes, DeviceChanges, IdentityChanges}, store::{Changes, DeviceChanges, IdentityChanges},
}; };
@ -2352,7 +2425,12 @@ mod test {
assert!(store.load_identity().await.unwrap().is_none()); assert!(store.load_identity().await.unwrap().is_none());
let identity = PrivateCrossSigningIdentity::new((&*store.user_id).clone()).await; let identity = PrivateCrossSigningIdentity::new((&*store.user_id).clone()).await;
store.save_identity(identity.clone()).await.unwrap(); let changes = Changes {
private_identity: Some(identity.clone()),
..Default::default()
};
store.save_changes(changes).await.unwrap();
let loaded_identity = store.load_identity().await.unwrap().unwrap(); let loaded_identity = store.load_identity().await.unwrap().unwrap();
assert_eq!(identity.user_id(), loaded_identity.user_id()); assert_eq!(identity.user_id(), loaded_identity.user_id());
} }
@ -2371,4 +2449,21 @@ mod test {
store.remove_value(&key).await.unwrap(); store.remove_value(&key).await.unwrap();
assert!(store.get_value(&key).await.unwrap().is_none()); assert!(store.get_value(&key).await.unwrap().is_none());
} }
#[tokio::test(threaded_scheduler)]
async fn olm_hash_saving() {
let (_, store, _dir) = get_loaded_store().await;
let hash = OlmMessageHash {
sender_key: "test_sender".to_owned(),
hash: "test_hash".to_owned(),
};
let mut changes = Changes::default();
changes.message_hashes.push(hash.clone());
assert!(!store.is_message_known(&hash).await.unwrap());
store.save_changes(changes).await.unwrap();
assert!(store.is_message_known(&hash).await.unwrap());
}
} }

View File

@ -117,9 +117,6 @@ impl Default for AcceptedProtocols {
} }
} }
// TODO implement expiration of the verification flow using the timeouts defined
// in the spec.
/// A type level state machine modeling the Sas flow. /// A type level state machine modeling the Sas flow.
/// ///
/// This is the generic struc holding common data between the different states /// This is the generic struc holding common data between the different states