Merge branch 'master' of https://github.com/matrix-org/matrix-rust-sdk into messages

master
Devin R 2020-05-05 06:57:37 -04:00
commit bfa9c0fda9
14 changed files with 653 additions and 321 deletions

View File

@ -23,39 +23,39 @@ http = "0.2.1"
url = "2.1.1" url = "2.1.1"
async-trait = "0.1.30" async-trait = "0.1.30"
serde = "1.0.106" serde = "1.0.106"
serde_json = "1.0.51" serde_json = "1.0.52"
uuid = { version = "0.8.1", features = ["v4"] } uuid = { version = "0.8.1", features = ["v4"] }
matrix-sdk-types = { path = "../matrix_sdk_types" } matrix-sdk-types = { path = "../matrix_sdk_types" }
matrix-sdk-crypto = { path = "../matrix_sdk_crypto", optional = true } matrix-sdk-crypto = { path = "../matrix_sdk_crypto", optional = true }
# Misc dependencies # Misc dependencies
thiserror = "1.0.14" thiserror = "1.0.16"
tracing = "0.1.13" tracing = "0.1.13"
atomic = "0.4.5" atomic = "0.4.5"
dashmap = "3.10.0" dashmap = "3.11.1"
[dependencies.tracing-futures] [dependencies.tracing-futures]
version = "0.2.3" version = "0.2.4"
default-features = false default-features = false
features = ["std", "std-future"] features = ["std", "std-future"]
[dependencies.tokio] [dependencies.tokio]
version = "0.2.16" version = "0.2.20"
default-features = false default-features = false
features = ["sync", "time", "fs"] features = ["sync", "time", "fs"]
[dependencies.sqlx] [dependencies.sqlx]
version = "0.3.3" version = "0.3.4"
optional = true optional = true
default-features = false default-features = false
features = ["runtime-tokio", "sqlite"] features = ["runtime-tokio", "sqlite"]
[dev-dependencies] [dev-dependencies]
tokio = { version = "0.2.16", features = ["rt-threaded", "macros"] } tokio = { version = "0.2.20", features = ["rt-threaded", "macros"] }
ruma-identifiers = { version = "0.16.0", features = ["rand"] } ruma-identifiers = { version = "0.16.1", features = ["rand"] }
serde_json = "1.0.51" serde_json = "1.0.52"
tracing-subscriber = "0.2.4" tracing-subscriber = "0.2.5"
tempfile = "3.1.0" tempfile = "3.1.0"
mockito = "0.25.1" mockito = "0.25.1"
lazy_static = "1.4.0" lazy_static = "1.4.0"

View File

@ -111,7 +111,7 @@ impl Client {
pub fn new(session: Option<Session>) -> Result<Self> { pub fn new(session: Option<Session>) -> Result<Self> {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
let olm = match &session { let olm = match &session {
Some(s) => Some(OlmMachine::new(&s.user_id, &s.device_id)?), Some(s) => Some(OlmMachine::new(&s.user_id, &s.device_id)),
None => None, None => None,
}; };
@ -199,7 +199,7 @@ impl Client {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
{ {
let mut olm = self.olm.lock().await; let mut olm = self.olm.lock().await;
*olm = Some(OlmMachine::new(&response.user_id, &response.device_id)?); *olm = Some(OlmMachine::new(&response.user_id, &response.device_id));
} }
Ok(()) Ok(())
@ -261,12 +261,15 @@ impl Client {
/// Returns true if the room name changed, false otherwise. /// Returns true if the room name changed, false otherwise.
pub(crate) fn handle_push_rules(&mut self, event: &PushRulesEvent) -> bool { pub(crate) fn handle_push_rules(&mut self, event: &PushRulesEvent) -> bool {
// TODO this is basically a stub // TODO this is basically a stub
if self.push_ruleset.as_ref() == Some(&event.content.global) { // TODO ruma removed PartialEq for evens, so this doesn't work anymore.
false // Returning always true for now should be ok here since those don't
} else { // change often.
self.push_ruleset = Some(event.content.global.clone()); // if self.push_ruleset.as_ref() == Some(&event.content.global) {
true // false
} // } else {
self.push_ruleset = Some(event.content.global.clone());
true
// }
} }
/// Receive a timeline event for a joined room and update the client state. /// Receive a timeline event for a joined room and update the client state.
@ -294,16 +297,13 @@ impl Client {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
{ {
match e { if let RoomEvent::RoomEncrypted(ref mut e) = e {
RoomEvent::RoomEncrypted(ref mut e) => { e.room_id = Some(room_id.to_owned());
e.room_id = Some(room_id.to_owned()); let mut olm = self.olm.lock().await;
let mut olm = self.olm.lock().await;
if let Some(o) = &mut *olm { if let Some(o) = &mut *olm {
decrypted_event = o.decrypt_room_event(&e).await.ok(); decrypted_event = o.decrypt_room_event(&e).await.ok();
}
} }
_ => (),
} }
} }
@ -535,12 +535,15 @@ impl Client {
) -> Result<MessageEventContent> { ) -> Result<MessageEventContent> {
let mut olm = self.olm.lock().await; let mut olm = self.olm.lock().await;
match &mut *olm { // TODO enable this again once we can send encrypted event
Some(o) => Ok(MessageEventContent::Encrypted( // contents with ruma.
o.encrypt(room_id, content).await?, // match &mut *olm {
)), // Some(o) => Ok(MessageEventContent::Encrypted(
None => panic!("Olm machine wasn't started"), // o.encrypt(room_id, content).await?,
} // )),
// None => panic!("Olm machine wasn't started"),
// }
Ok(content)
} }
/// Get a tuple of device and one-time keys that need to be uploaded. /// Get a tuple of device and one-time keys that need to be uploaded.

View File

@ -26,7 +26,7 @@ use thiserror::Error;
use url::ParseError; use url::ParseError;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_crypto::OlmError; use matrix_sdk_crypto::{MegolmError, OlmError};
/// Result type of the rust-sdk. /// Result type of the rust-sdk.
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -59,6 +59,9 @@ pub enum Error {
/// An error occurred during a E2EE operation. /// An error occurred during a E2EE operation.
#[error(transparent)] #[error(transparent)]
OlmError(#[from] OlmError), OlmError(#[from] OlmError),
/// An error occurred during a E2EE group operation.
#[error(transparent)]
MegolmError(#[from] MegolmError),
} }
impl From<RumaResponseError<RumaClientError>> for Error { impl From<RumaResponseError<RumaClientError>> for Error {

View File

@ -31,7 +31,7 @@ use crate::{Result, Room, Session};
/// When implementing `StateStore` for something other than the filesystem /// When implementing `StateStore` for something other than the filesystem
/// implement `From<ClientState> for YourDbType` this allows for easy conversion /// implement `From<ClientState> for YourDbType` this allows for easy conversion
/// when needed in `StateStore::load/store_client_state` /// when needed in `StateStore::load/store_client_state`
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ClientState { pub struct ClientState {
/// The current sync token that should be used for the next sync call. /// The current sync token that should be used for the next sync call.
pub sync_token: Option<Token>, pub sync_token: Option<Token>,
@ -41,6 +41,12 @@ pub struct ClientState {
pub push_ruleset: Option<Ruleset>, pub push_ruleset: Option<Ruleset>,
} }
impl PartialEq for ClientState {
fn eq(&self, other: &Self) -> bool {
self.sync_token == other.sync_token && self.ignored_users == other.ignored_users
}
}
impl ClientState { impl ClientState {
pub fn from_base_client(client: &BaseClient) -> ClientState { pub fn from_base_client(client: &BaseClient) -> ClientState {
let BaseClient { let BaseClient {

View File

@ -101,8 +101,8 @@ impl Device {
} }
/// Get the key of the given key algorithm belonging to this device. /// Get the key of the given key algorithm belonging to this device.
pub fn get_key(&self, algorithm: &KeyAlgorithm) -> Option<&String> { pub fn get_key(&self, algorithm: KeyAlgorithm) -> Option<&String> {
self.keys.get(algorithm) self.keys.get(&algorithm)
} }
/// Get a map containing all the device keys. /// Get a map containing all the device keys.
@ -274,11 +274,11 @@ pub(crate) mod test {
device.display_name().as_ref().unwrap() device.display_name().as_ref().unwrap()
); );
assert_eq!( assert_eq!(
device.get_key(&KeyAlgorithm::Curve25519).unwrap(), device.get_key(KeyAlgorithm::Curve25519).unwrap(),
"wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4" "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4"
); );
assert_eq!( assert_eq!(
device.get_key(&KeyAlgorithm::Ed25519).unwrap(), device.get_key(KeyAlgorithm::Ed25519).unwrap(),
"nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM" "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM"
); );
} }

View File

@ -19,46 +19,101 @@ use thiserror::Error;
use super::store::CryptoStoreError; use super::store::CryptoStoreError;
pub type Result<T> = std::result::Result<T, OlmError>; pub type OlmResult<T> = Result<T, OlmError>;
pub type MegolmResult<T> = Result<T, MegolmError>;
/// Error representing a failure during a device to device cryptographic
/// operation.
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum OlmError { pub enum OlmError {
#[error("signature verification failed")] /// The event that should have been decrypted is malformed.
Signature(#[from] SignatureError), #[error(transparent)]
#[error("failed to read or write to the crypto store {0}")] EventError(#[from] EventError),
Store(#[from] CryptoStoreError),
#[error("decryption failed likely because a Olm session was wedged")] /// The received decrypted event couldn't be deserialized.
SessionWedged, #[error(transparent)]
#[error("the Olm message has a unsupported type")] JsonError(#[from] SerdeError),
UnsupportedOlmType,
#[error("the Encrypted message has been encrypted with a unsupported algorithm.")] /// The underlying Olm session operation returned an error.
UnsupportedAlgorithm,
#[error("the Encrypted message doesn't contain a ciphertext for our device")]
MissingCiphertext,
#[error("decryption failed because the session to decrypt the message is missing")]
MissingSession,
#[error("the Encrypted message is missing the signing key of the sender")]
MissingSigningKey,
#[error("can't finish Olm Session operation {0}")] #[error("can't finish Olm Session operation {0}")]
OlmSession(#[from] OlmSessionError), OlmSession(#[from] OlmSessionError),
/// The underlying group session operation returned an error.
#[error("can't finish Olm Session operation {0}")] #[error("can't finish Olm Session operation {0}")]
OlmGroupSession(#[from] OlmGroupSessionError), OlmGroupSession(#[from] OlmGroupSessionError),
#[error("error deserializing a string to json")]
JsonError(#[from] SerdeError), /// The storage layer returned an error.
#[error("the provided JSON value isn't an object")] #[error("failed to read or write to the crypto store {0}")]
NotAnObject, Store(#[from] CryptoStoreError),
/// The session with a device has become corrupted.
#[error("decryption failed likely because a Olm session was wedged")]
SessionWedged,
} }
pub type VerificationResult<T> = std::result::Result<T, SignatureError>; /// Error representing a failure during a group encryption operation.
#[derive(Error, Debug)]
pub enum MegolmError {
/// The event that should have been decrypted is malformed.
#[error(transparent)]
EventError(#[from] EventError),
/// The received decrypted event couldn't be deserialized.
#[error(transparent)]
JsonError(#[from] SerdeError),
/// Decryption failed because the session needed to decrypt the event is
/// missing.
#[error("decryption failed because the session to decrypt the message is missing")]
MissingSession,
/// The underlying group session operation returned an error.
#[error("can't finish Olm group session operation {0}")]
OlmGroupSession(#[from] OlmGroupSessionError),
/// The storage layer returned an error.
#[error(transparent)]
Store(#[from] CryptoStoreError),
}
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum SignatureError { pub enum EventError {
#[error("the Olm message has a unsupported type")]
UnsupportedOlmType,
#[error("the Encrypted message has been encrypted with a unsupported algorithm.")]
UnsupportedAlgorithm,
#[error("the provided JSON value isn't an object")] #[error("the provided JSON value isn't an object")]
NotAnObject, NotAnObject,
#[error("the Encrypted message doesn't contain a ciphertext for our device")]
MissingCiphertext,
#[error("the Encrypted message is missing the signing key of the sender")]
MissingSigningKey,
#[error("the Encrypted message is missing the field {0}")]
MissingField(String),
#[error("the sender of the plaintext doesn't match the sender of the encrypted message.")]
MissmatchedSender,
#[error("the keys of the message don't match the keys in our database.")]
MissmatchedKeys,
}
#[derive(Error, Debug)]
pub(crate) enum SignatureError {
#[error("the provided JSON value isn't an object")]
NotAnObject,
#[error("the provided JSON object doesn't contain a signatures field")] #[error("the provided JSON object doesn't contain a signatures field")]
NoSignatureFound, NoSignatureFound,
#[error("the provided JSON object can't be converted to a canonical representation")] #[error("the provided JSON object can't be converted to a canonical representation")]
CanonicalJsonError(CjsonError), CanonicalJsonError(CjsonError),
#[error("the signature didn't match the provided key")] #[error("the signature didn't match the provided key")]
VerificationError, VerificationError,
} }

View File

@ -15,6 +15,16 @@
//! This is the encryption part of the matrix-sdk. It contains a state machine //! This is the encryption part of the matrix-sdk. It contains a state machine
//! that will aid in adding encryption support to a client library. //! that will aid in adding encryption support to a client library.
#![deny(
missing_debug_implementations,
missing_docs,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
unused_qualifications
)]
mod device; mod device;
mod error; mod error;
mod machine; mod machine;
@ -23,6 +33,10 @@ mod olm;
mod store; mod store;
pub use device::{Device, TrustState}; pub use device::{Device, TrustState};
pub use error::OlmError; pub use error::{MegolmError, OlmError};
pub use machine::{OlmMachine, OneTimeKeys}; pub use machine::{OlmMachine, OneTimeKeys};
pub use memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
pub use olm::{Account, InboundGroupSession, OutboundGroupSession, Session};
#[cfg(feature = "sqlite-cryptostore")]
pub use store::sqlite::SqliteStore;
pub use store::{CryptoStore, CryptoStoreError}; pub use store::{CryptoStore, CryptoStoreError};

File diff suppressed because it is too large Load Diff

View File

@ -23,7 +23,7 @@ use super::olm::{InboundGroupSession, Session};
use matrix_sdk_types::identifiers::{DeviceId, RoomId, UserId}; use matrix_sdk_types::identifiers::{DeviceId, RoomId, UserId};
/// In-memory store for Olm Sessions. /// In-memory store for Olm Sessions.
#[derive(Debug)] #[derive(Debug, Default)]
pub struct SessionStore { pub struct SessionStore {
entries: HashMap<String, Arc<Mutex<Vec<Session>>>>, entries: HashMap<String, Arc<Mutex<Vec<Session>>>>,
} }
@ -69,7 +69,7 @@ impl SessionStore {
} }
} }
#[derive(Debug)] #[derive(Debug, Default)]
/// In-memory store that houlds inbound group sessions. /// In-memory store that houlds inbound group sessions.
pub struct GroupSessionStore { pub struct GroupSessionStore {
entries: HashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>, entries: HashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>,
@ -127,12 +127,13 @@ impl GroupSessionStore {
} }
/// In-memory store holding the devices of users. /// In-memory store holding the devices of users.
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct DeviceStore { pub struct DeviceStore {
entries: Arc<DashMap<UserId, DashMap<String, Device>>>, entries: Arc<DashMap<UserId, DashMap<String, Device>>>,
} }
/// A read only view over all devices belonging to a user. /// A read only view over all devices belonging to a user.
#[derive(Debug)]
pub struct UserDevices { pub struct UserDevices {
entries: ReadOnlyView<DeviceId, Device>, entries: ReadOnlyView<DeviceId, Device>,
} }
@ -192,7 +193,7 @@ impl DeviceStore {
self.entries self.entries
.get(user_id) .get(user_id)
.and_then(|m| m.remove(device_id)) .and_then(|m| m.remove(device_id))
.and_then(|(_, d)| Some(d)) .map(|(_, d)| d)
} }
/// Get a read-only view over all devices of the given user. /// Get a read-only view over all devices of the given user.

View File

@ -38,9 +38,10 @@ pub use olm_rs::{
use matrix_sdk_types::api::r0::keys::SignedKey; use matrix_sdk_types::api::r0::keys::SignedKey;
use matrix_sdk_types::identifiers::RoomId; use matrix_sdk_types::identifiers::RoomId;
/// The Olm account. /// Account holding identity keys for which sessions can be created.
///
/// An account is the central identity for encrypted communication between two /// An account is the central identity for encrypted communication between two
/// devices. It holds the two identity key pairs for a device. /// devices.
#[derive(Clone)] #[derive(Clone)]
pub struct Account { pub struct Account {
inner: Arc<Mutex<OlmAccount>>, inner: Arc<Mutex<OlmAccount>>,
@ -58,8 +59,15 @@ impl fmt::Debug for Account {
} }
} }
#[cfg_attr(tarpaulin, skip)]
impl Default for Account {
fn default() -> Self {
Self::new()
}
}
impl Account { impl Account {
/// Create a new account. /// Create a fresh new account, this will generate the identity key-pair.
pub fn new() -> Self { pub fn new() -> Self {
let account = OlmAccount::new(); let account = OlmAccount::new();
let identity_keys = account.parsed_identity_keys(); let identity_keys = account.parsed_identity_keys();
@ -182,7 +190,7 @@ impl Account {
inner: Arc::new(Mutex::new(session)), inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id), session_id: Arc::new(session_id),
sender_key: Arc::new(their_identity_key.to_owned()), sender_key: Arc::new(their_identity_key.to_owned()),
creation_time: Arc::new(now.clone()), creation_time: Arc::new(now),
last_use_time: Arc::new(now), last_use_time: Arc::new(now),
}) })
} }
@ -223,7 +231,7 @@ impl Account {
inner: Arc::new(Mutex::new(session)), inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id), session_id: Arc::new(session_id),
sender_key: Arc::new(their_identity_key.to_owned()), sender_key: Arc::new(their_identity_key.to_owned()),
creation_time: Arc::new(now.clone()), creation_time: Arc::new(now),
last_use_time: Arc::new(now), last_use_time: Arc::new(now),
}) })
} }
@ -235,10 +243,8 @@ impl PartialEq for Account {
} }
} }
/// The Olm Session. /// Cryptographic session that enables secure communication between two
/// /// `Account`s
/// Sessions are used to exchange encrypted messages between two
/// accounts/devices.
#[derive(Clone)] #[derive(Clone)]
pub struct Session { pub struct Session {
inner: Arc<Mutex<OlmSession>>, inner: Arc<Mutex<OlmSession>>,
@ -371,7 +377,7 @@ impl PartialEq for Session {
/// The private session key of a group session. /// The private session key of a group session.
/// Can be used to create a new inbound group session. /// Can be used to create a new inbound group session.
#[derive(Clone, Serialize, Zeroize)] #[derive(Clone, Debug, Serialize, Zeroize)]
#[zeroize(drop)] #[zeroize(drop)]
pub struct GroupSessionKey(pub String); pub struct GroupSessionKey(pub String);

View File

@ -52,8 +52,11 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn save_session(&mut self, session: Session) -> Result<()> { async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
self.sessions.add(session).await; for session in sessions {
let _ = self.sessions.add(session.clone()).await;
}
Ok(()) Ok(())
} }
@ -84,12 +87,13 @@ impl CryptoStore for MemoryStore {
Ok(self.tracked_users.insert(user.clone())) Ok(self.tracked_users.insert(user.clone()))
} }
#[allow(clippy::ptr_arg)]
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> { async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
Ok(self.devices.get(user_id, device_id)) Ok(self.devices.get(user_id, device_id))
} }
async fn delete_device(&self, device: Device) -> Result<()> { async fn delete_device(&self, device: Device) -> Result<()> {
self.devices.remove(device.user_id(), device.device_id()); let _ = self.devices.remove(device.user_id(), device.device_id());
Ok(()) Ok(())
} }
@ -97,8 +101,11 @@ impl CryptoStore for MemoryStore {
Ok(self.devices.user_devices(user_id)) Ok(self.devices.user_devices(user_id))
} }
async fn save_device(&self, device: Device) -> Result<()> { async fn save_devices(&self, devices: &[Device]) -> Result<()> {
self.devices.add(device); for device in devices {
let _ = self.devices.add(device.clone());
}
Ok(()) Ok(())
} }
} }
@ -122,7 +129,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none()); assert!(store.load_account().await.unwrap().is_none());
store.save_account(account).await.unwrap(); store.save_account(account).await.unwrap();
store.save_session(session.clone()).await.unwrap(); store.save_sessions(&[session.clone()]).await.unwrap();
let sessions = store let sessions = store
.get_sessions(&session.sender_key) .get_sessions(&session.sender_key)
@ -150,7 +157,7 @@ mod test {
.unwrap(); .unwrap();
let mut store = MemoryStore::new(); let mut store = MemoryStore::new();
store let _ = store
.save_inbound_group_session(inbound.clone()) .save_inbound_group_session(inbound.clone())
.await .await
.unwrap(); .unwrap();
@ -168,7 +175,7 @@ mod test {
let device = get_device(); let device = get_device();
let store = MemoryStore::new(); let store = MemoryStore::new();
store.save_device(device.clone()).await.unwrap(); store.save_devices(&[device.clone()]).await.unwrap();
let loaded_device = store let loaded_device = store
.get_device(device.user_id(), device.device_id()) .get_device(device.user_id(), device.device_id())
@ -205,6 +212,6 @@ mod test {
let tracked_users = store.tracked_users(); let tracked_users = store.tracked_users();
tracked_users.contains(device.user_id()); let _ = tracked_users.contains(device.user_id());
} }
} }

View File

@ -37,33 +37,54 @@ pub mod sqlite;
use sqlx::Error as SqlxError; use sqlx::Error as SqlxError;
#[derive(Error, Debug)] #[derive(Error, Debug)]
/// The crypto store's error type.
pub enum CryptoStoreError { pub enum CryptoStoreError {
#[error("can't read or write from the store")] /// The account that owns the sessions, group sessions, and devices wasn't
Io(#[from] IoError), /// found.
#[error("can't finish Olm Account operation {0}")] #[error("can't save/load sessions or group sessions in the store before an account is stored")]
OlmAccount(#[from] OlmAccountError),
#[error("can't finish Olm Session operation {0}")]
OlmSession(#[from] OlmSessionError),
#[error("can't finish Olm GruoupSession operation {0}")]
OlmGroupSession(#[from] OlmGroupSessionError),
#[error("URL can't be parsed")]
UrlParse(#[from] ParseError),
#[error("error serializing data for the database")]
Serialization(#[from] SerdeError),
#[error("can't load session timestamps")]
SessionTimestampError,
#[error("can't save/load sessions or group sessions in the store before a account is stored")]
AccountUnset, AccountUnset,
/// SQL error occurred.
// TODO flatten the SqlxError to make it easier for other store // TODO flatten the SqlxError to make it easier for other store
// implementations. // implementations.
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
#[error("database error")] #[error(transparent)]
DatabaseError(#[from] SqlxError), DatabaseError(#[from] SqlxError),
/// An IO error occurred.
#[error(transparent)]
Io(#[from] IoError),
/// The underlying Olm Account operation returned an error.
#[error(transparent)]
OlmAccount(#[from] OlmAccountError),
/// The underlying Olm session operation returned an error.
#[error(transparent)]
OlmSession(#[from] OlmSessionError),
/// The underlying Olm group session operation returned an error.
#[error(transparent)]
OlmGroupSession(#[from] OlmGroupSessionError),
/// A session time-stamp couldn't be loaded.
#[error("can't load session timestamps")]
SessionTimestampError,
/// The store failed to (de)serialize a data type.
#[error(transparent)]
Serialization(#[from] SerdeError),
/// An error occurred while parsing an URL.
#[error(transparent)]
UrlParse(#[from] ParseError),
} }
pub type Result<T> = std::result::Result<T, CryptoStoreError>; pub type Result<T> = std::result::Result<T, CryptoStoreError>;
#[async_trait] #[async_trait]
/// Trait abstracting a store that the `OlmMachine` uses to store cryptographic
/// keys.
pub trait CryptoStore: Debug + Send + Sync { pub trait CryptoStore: Debug + Send + Sync {
/// Load an account that was previously stored. /// Load an account that was previously stored.
async fn load_account(&mut self) -> Result<Option<Account>>; async fn load_account(&mut self) -> Result<Option<Account>>;
@ -75,12 +96,12 @@ pub trait CryptoStore: Debug + Send + Sync {
/// * `account` - The account that should be stored. /// * `account` - The account that should be stored.
async fn save_account(&mut self, account: Account) -> Result<()>; async fn save_account(&mut self, account: Account) -> Result<()>;
/// Save the given session in the store. /// Save the given sessions in the store.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `session` - The session that should be stored. /// * `session` - The sessions that should be stored.
async fn save_session(&mut self, session: Session) -> Result<()>; async fn save_sessions(&mut self, session: &[Session]) -> Result<()>;
/// Get all the sessions that belong to the given sender key. /// Get all the sessions that belong to the given sender key.
/// ///
@ -126,12 +147,12 @@ pub trait CryptoStore: Debug + Send + Sync {
/// * `user` - The user that should be marked as tracked. /// * `user` - The user that should be marked as tracked.
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>; async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
/// Save the given device in the store. /// Save the given devices in the store.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `device` - The device that should be stored. /// * `device` - The device that should be stored.
async fn save_device(&self, device: Device) -> Result<()>; async fn save_devices(&self, devices: &[Device]) -> Result<()>;
/// Delete the given device from the store. /// Delete the given device from the store.
/// ///
@ -147,6 +168,7 @@ pub trait CryptoStore: Debug + Send + Sync {
/// * `user_id` - The user that the device belongs to. /// * `user_id` - The user that the device belongs to.
/// ///
/// * `device_id` - The unique id of the device. /// * `device_id` - The unique id of the device.
#[allow(clippy::ptr_arg)]
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>>; async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>>;
/// Get all the devices of the given user. /// Get all the devices of the given user.

View File

@ -35,6 +35,7 @@ use matrix_sdk_types::api::r0::keys::KeyAlgorithm;
use matrix_sdk_types::events::Algorithm; use matrix_sdk_types::events::Algorithm;
use matrix_sdk_types::identifiers::{DeviceId, RoomId, UserId}; use matrix_sdk_types::identifiers::{DeviceId, RoomId, UserId};
/// SQLite based implementation of a `CryptoStore`.
pub struct SqliteStore { pub struct SqliteStore {
user_id: Arc<String>, user_id: Arc<String>,
device_id: Arc<String>, device_id: Arc<String>,
@ -53,6 +54,17 @@ pub struct SqliteStore {
static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
impl SqliteStore { impl SqliteStore {
/// Open a new `SqliteStore`.
///
/// # Arguments
///
/// * `user_id` - The unique id of the user for which the store should be
/// opened.
///
/// * `device_id` - The unique id of the device for which the store should
/// be opened.
///
/// * `path` - The path where the database file should reside in.
pub async fn open<P: AsRef<Path>>( pub async fn open<P: AsRef<Path>>(
user_id: &UserId, user_id: &UserId,
device_id: &str, device_id: &str,
@ -61,6 +73,20 @@ impl SqliteStore {
SqliteStore::open_helper(user_id, device_id, path, None).await SqliteStore::open_helper(user_id, device_id, path, None).await
} }
/// Open a new `SqliteStore`.
///
/// # Arguments
///
/// * `user_id` - The unique id of the user for which the store should be
/// opened.
///
/// * `device_id` - The unique id of the device for which the store should
/// be opened.
///
/// * `path` - The path where the database file should reside in.
///
/// * `passphrase` - The passphrase that should be used to securely store
/// the encryption keys.
pub async fn open_with_passphrase<P: AsRef<Path>>( pub async fn open_with_passphrase<P: AsRef<Path>>(
user_id: &UserId, user_id: &UserId,
device_id: &str, device_id: &str,
@ -321,7 +347,8 @@ impl SqliteStore {
for row in rows { for row in rows {
let device_row_id = row.0; let device_row_id = row.0;
let user_id = if let Ok(u) = UserId::try_from(&row.1 as &str) { let user_id: &str = &row.1;
let user_id = if let Ok(u) = UserId::try_from(user_id) {
u u
} else { } else {
continue; continue;
@ -339,7 +366,10 @@ impl SqliteStore {
let algorithms = algorithm_rows let algorithms = algorithm_rows
.iter() .iter()
.map(|row| Algorithm::from(&row.0 as &str)) .map(|row| {
let algorithm: &str = &row.0;
Algorithm::from(algorithm)
})
.collect::<Vec<Algorithm>>(); .collect::<Vec<Algorithm>>();
let key_rows: Vec<(String, String)> = let key_rows: Vec<(String, String)> =
@ -351,7 +381,8 @@ impl SqliteStore {
let mut keys = BTreeMap::new(); let mut keys = BTreeMap::new();
for row in key_rows { for row in key_rows {
let algorithm = if let Ok(a) = KeyAlgorithm::try_from(&row.0 as &str) { let algorithm: &str = &row.0;
let algorithm = if let Ok(a) = KeyAlgorithm::try_from(algorithm) {
a a
} else { } else {
continue; continue;
@ -480,12 +511,12 @@ impl CryptoStore for SqliteStore {
let mut group_sessions = self.load_inbound_group_sessions().await?; let mut group_sessions = self.load_inbound_group_sessions().await?;
let _ = group_sessions group_sessions
.drain(..) .drain(..)
.map(|s| { .map(|s| {
self.inbound_group_sessions.add(s); self.inbound_group_sessions.add(s);
}) })
.collect::<()>(); .for_each(drop);
let devices = self.load_devices().await?; let devices = self.load_devices().await?;
mem::replace(&mut self.devices, devices); mem::replace(&mut self.devices, devices);
@ -527,32 +558,35 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn save_session(&mut self, session: Session) -> Result<()> { async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
self.lazy_load_sessions(&session.sender_key).await?; // TODO turn this into a transaction
self.sessions.add(session.clone()).await; for session in sessions {
self.lazy_load_sessions(&session.sender_key).await?;
self.sessions.add(session.clone()).await;
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let session_id = session.session_id(); let session_id = session.session_id();
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?; let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?; let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
let pickle = session.pickle(self.get_pickle_mode()).await; let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
query( query(
"REPLACE INTO sessions ( "REPLACE INTO sessions (
session_id, account_id, creation_time, last_use_time, sender_key, pickle session_id, account_id, creation_time, last_use_time, sender_key, pickle
) VALUES (?, ?, ?, ?, ?, ?)", ) VALUES (?, ?, ?, ?, ?, ?)",
) )
.bind(&session_id) .bind(&session_id)
.bind(&account_id) .bind(&account_id)
.bind(&*creation_time) .bind(&*creation_time)
.bind(&*last_use_time) .bind(&*last_use_time)
.bind(&*session.sender_key) .bind(&*session.sender_key)
.bind(&pickle) .bind(&pickle)
.execute(&mut *connection) .execute(&mut *connection)
.await?; .await?;
}
Ok(()) Ok(())
} }
@ -608,15 +642,35 @@ impl CryptoStore for SqliteStore {
Ok(self.tracked_users.insert(user.clone())) Ok(self.tracked_users.insert(user.clone()))
} }
async fn save_device(&self, device: Device) -> Result<()> { async fn save_devices(&self, devices: &[Device]) -> Result<()> {
self.devices.add(device.clone()); // TODO turn this into a bulk transaction.
self.save_device_helper(device).await for device in devices {
self.devices.add(device.clone());
self.save_device_helper(device.clone()).await?
}
Ok(())
} }
async fn delete_device(&self, device: Device) -> Result<()> { async fn delete_device(&self, device: Device) -> Result<()> {
todo!() let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
query(
"DELETE FROM devices
WHERE account_id = ?1 and user_id = ?2 and device_id = ?3
",
)
.bind(account_id)
.bind(&device.user_id().to_string())
.bind(&device.device_id())
.execute(&mut *connection)
.await?;
Ok(())
} }
#[allow(clippy::ptr_arg)]
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> { async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
Ok(self.devices.get(user_id, device_id)) Ok(self.devices.get(user_id, device_id))
} }
@ -801,14 +855,14 @@ mod test {
let (mut store, _dir) = get_store(None).await; let (mut store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
assert!(store.save_session(session.clone()).await.is_err()); assert!(store.save_sessions(&[session.clone()]).await.is_err());
store store
.save_account(account.clone()) .save_account(account.clone())
.await .await
.expect("Can't save account"); .expect("Can't save account");
store.save_session(session).await.unwrap(); store.save_sessions(&[session]).await.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -819,7 +873,7 @@ mod test {
.save_account(account.clone()) .save_account(account.clone())
.await .await
.expect("Can't save account"); .expect("Can't save account");
store.save_session(session.clone()).await.unwrap(); store.save_sessions(&[session.clone()]).await.unwrap();
let sessions = store let sessions = store
.load_sessions_for(&session.sender_key) .load_sessions_for(&session.sender_key)
@ -841,7 +895,7 @@ mod test {
.save_account(account.clone()) .save_account(account.clone())
.await .await
.expect("Can't save account"); .expect("Can't save account");
store.save_session(session).await.unwrap(); store.save_sessions(&[session]).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
let sessions_lock = sessions.lock().await; let sessions_lock = sessions.lock().await;
@ -937,7 +991,7 @@ mod test {
let (_account, store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
store.save_device(device.clone()).await.unwrap(); store.save_devices(&[device.clone()]).await.unwrap();
drop(store); drop(store);
@ -966,4 +1020,27 @@ mod test {
assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id()); assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id());
assert_eq!(user_devices.devices().nth(0).unwrap(), &device); assert_eq!(user_devices.devices().nth(0).unwrap(), &device);
} }
#[tokio::test]
async fn device_deleting() {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
store.save_devices(&[device.clone()]).await.unwrap();
store.delete_device(device.clone()).await.unwrap();
let mut store =
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
.await
.expect("Can't create store");
store.load_account().await.unwrap();
let loaded_device = store
.get_device(device.user_id(), device.device_id())
.await
.unwrap();
assert!(loaded_device.is_none());
}
} }

View File

@ -12,7 +12,7 @@ version = "0.1.0"
[dependencies] [dependencies]
js_int = "0.1.5" js_int = "0.1.5"
ruma-api = "0.16.0-rc.2" ruma-api = "0.16.0-rc.3"
ruma-client-api = { version = "0.8.0-rc.5" } ruma-client-api = "0.8.0-rc.5"
ruma-events = { version = "0.21.0-beta.1" } ruma-events = "0.21.0"
ruma-identifiers = "0.16.0" ruma-identifiers = "0.16.1"