Merge branch 'master' of https://github.com/matrix-org/matrix-rust-sdk into messages
This commit is contained in:
commit
bfa9c0fda9
14 changed files with 653 additions and 321 deletions
|
@ -23,39 +23,39 @@ http = "0.2.1"
|
|||
url = "2.1.1"
|
||||
async-trait = "0.1.30"
|
||||
serde = "1.0.106"
|
||||
serde_json = "1.0.51"
|
||||
serde_json = "1.0.52"
|
||||
uuid = { version = "0.8.1", features = ["v4"] }
|
||||
|
||||
matrix-sdk-types = { path = "../matrix_sdk_types" }
|
||||
matrix-sdk-crypto = { path = "../matrix_sdk_crypto", optional = true }
|
||||
|
||||
# Misc dependencies
|
||||
thiserror = "1.0.14"
|
||||
thiserror = "1.0.16"
|
||||
tracing = "0.1.13"
|
||||
atomic = "0.4.5"
|
||||
dashmap = "3.10.0"
|
||||
dashmap = "3.11.1"
|
||||
|
||||
[dependencies.tracing-futures]
|
||||
version = "0.2.3"
|
||||
version = "0.2.4"
|
||||
default-features = false
|
||||
features = ["std", "std-future"]
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "0.2.16"
|
||||
version = "0.2.20"
|
||||
default-features = false
|
||||
features = ["sync", "time", "fs"]
|
||||
|
||||
[dependencies.sqlx]
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
optional = true
|
||||
default-features = false
|
||||
features = ["runtime-tokio", "sqlite"]
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "0.2.16", features = ["rt-threaded", "macros"] }
|
||||
ruma-identifiers = { version = "0.16.0", features = ["rand"] }
|
||||
serde_json = "1.0.51"
|
||||
tracing-subscriber = "0.2.4"
|
||||
tokio = { version = "0.2.20", features = ["rt-threaded", "macros"] }
|
||||
ruma-identifiers = { version = "0.16.1", features = ["rand"] }
|
||||
serde_json = "1.0.52"
|
||||
tracing-subscriber = "0.2.5"
|
||||
tempfile = "3.1.0"
|
||||
mockito = "0.25.1"
|
||||
lazy_static = "1.4.0"
|
||||
|
|
|
@ -111,7 +111,7 @@ impl Client {
|
|||
pub fn new(session: Option<Session>) -> Result<Self> {
|
||||
#[cfg(feature = "encryption")]
|
||||
let olm = match &session {
|
||||
Some(s) => Some(OlmMachine::new(&s.user_id, &s.device_id)?),
|
||||
Some(s) => Some(OlmMachine::new(&s.user_id, &s.device_id)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
|
@ -199,7 +199,7 @@ impl Client {
|
|||
#[cfg(feature = "encryption")]
|
||||
{
|
||||
let mut olm = self.olm.lock().await;
|
||||
*olm = Some(OlmMachine::new(&response.user_id, &response.device_id)?);
|
||||
*olm = Some(OlmMachine::new(&response.user_id, &response.device_id));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -261,12 +261,15 @@ impl Client {
|
|||
/// Returns true if the room name changed, false otherwise.
|
||||
pub(crate) fn handle_push_rules(&mut self, event: &PushRulesEvent) -> bool {
|
||||
// TODO this is basically a stub
|
||||
if self.push_ruleset.as_ref() == Some(&event.content.global) {
|
||||
false
|
||||
} else {
|
||||
self.push_ruleset = Some(event.content.global.clone());
|
||||
true
|
||||
}
|
||||
// TODO ruma removed PartialEq for evens, so this doesn't work anymore.
|
||||
// Returning always true for now should be ok here since those don't
|
||||
// change often.
|
||||
// if self.push_ruleset.as_ref() == Some(&event.content.global) {
|
||||
// false
|
||||
// } else {
|
||||
self.push_ruleset = Some(event.content.global.clone());
|
||||
true
|
||||
// }
|
||||
}
|
||||
|
||||
/// Receive a timeline event for a joined room and update the client state.
|
||||
|
@ -294,16 +297,13 @@ impl Client {
|
|||
|
||||
#[cfg(feature = "encryption")]
|
||||
{
|
||||
match e {
|
||||
RoomEvent::RoomEncrypted(ref mut e) => {
|
||||
e.room_id = Some(room_id.to_owned());
|
||||
let mut olm = self.olm.lock().await;
|
||||
if let RoomEvent::RoomEncrypted(ref mut e) = e {
|
||||
e.room_id = Some(room_id.to_owned());
|
||||
let mut olm = self.olm.lock().await;
|
||||
|
||||
if let Some(o) = &mut *olm {
|
||||
decrypted_event = o.decrypt_room_event(&e).await.ok();
|
||||
}
|
||||
if let Some(o) = &mut *olm {
|
||||
decrypted_event = o.decrypt_room_event(&e).await.ok();
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -535,12 +535,15 @@ impl Client {
|
|||
) -> Result<MessageEventContent> {
|
||||
let mut olm = self.olm.lock().await;
|
||||
|
||||
match &mut *olm {
|
||||
Some(o) => Ok(MessageEventContent::Encrypted(
|
||||
o.encrypt(room_id, content).await?,
|
||||
)),
|
||||
None => panic!("Olm machine wasn't started"),
|
||||
}
|
||||
// TODO enable this again once we can send encrypted event
|
||||
// contents with ruma.
|
||||
// match &mut *olm {
|
||||
// Some(o) => Ok(MessageEventContent::Encrypted(
|
||||
// o.encrypt(room_id, content).await?,
|
||||
// )),
|
||||
// None => panic!("Olm machine wasn't started"),
|
||||
// }
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
/// Get a tuple of device and one-time keys that need to be uploaded.
|
||||
|
|
|
@ -26,7 +26,7 @@ use thiserror::Error;
|
|||
use url::ParseError;
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use matrix_sdk_crypto::OlmError;
|
||||
use matrix_sdk_crypto::{MegolmError, OlmError};
|
||||
|
||||
/// Result type of the rust-sdk.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
@ -59,6 +59,9 @@ pub enum Error {
|
|||
/// An error occurred during a E2EE operation.
|
||||
#[error(transparent)]
|
||||
OlmError(#[from] OlmError),
|
||||
/// An error occurred during a E2EE group operation.
|
||||
#[error(transparent)]
|
||||
MegolmError(#[from] MegolmError),
|
||||
}
|
||||
|
||||
impl From<RumaResponseError<RumaClientError>> for Error {
|
||||
|
|
|
@ -31,7 +31,7 @@ use crate::{Result, Room, Session};
|
|||
/// When implementing `StateStore` for something other than the filesystem
|
||||
/// implement `From<ClientState> for YourDbType` this allows for easy conversion
|
||||
/// when needed in `StateStore::load/store_client_state`
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ClientState {
|
||||
/// The current sync token that should be used for the next sync call.
|
||||
pub sync_token: Option<Token>,
|
||||
|
@ -41,6 +41,12 @@ pub struct ClientState {
|
|||
pub push_ruleset: Option<Ruleset>,
|
||||
}
|
||||
|
||||
impl PartialEq for ClientState {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.sync_token == other.sync_token && self.ignored_users == other.ignored_users
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientState {
|
||||
pub fn from_base_client(client: &BaseClient) -> ClientState {
|
||||
let BaseClient {
|
||||
|
|
|
@ -101,8 +101,8 @@ impl Device {
|
|||
}
|
||||
|
||||
/// Get the key of the given key algorithm belonging to this device.
|
||||
pub fn get_key(&self, algorithm: &KeyAlgorithm) -> Option<&String> {
|
||||
self.keys.get(algorithm)
|
||||
pub fn get_key(&self, algorithm: KeyAlgorithm) -> Option<&String> {
|
||||
self.keys.get(&algorithm)
|
||||
}
|
||||
|
||||
/// Get a map containing all the device keys.
|
||||
|
@ -274,11 +274,11 @@ pub(crate) mod test {
|
|||
device.display_name().as_ref().unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
device.get_key(&KeyAlgorithm::Curve25519).unwrap(),
|
||||
device.get_key(KeyAlgorithm::Curve25519).unwrap(),
|
||||
"wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4"
|
||||
);
|
||||
assert_eq!(
|
||||
device.get_key(&KeyAlgorithm::Ed25519).unwrap(),
|
||||
device.get_key(KeyAlgorithm::Ed25519).unwrap(),
|
||||
"nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM"
|
||||
);
|
||||
}
|
||||
|
|
|
@ -19,46 +19,101 @@ use thiserror::Error;
|
|||
|
||||
use super::store::CryptoStoreError;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OlmError>;
|
||||
pub type OlmResult<T> = Result<T, OlmError>;
|
||||
pub type MegolmResult<T> = Result<T, MegolmError>;
|
||||
|
||||
/// Error representing a failure during a device to device cryptographic
|
||||
/// operation.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum OlmError {
|
||||
#[error("signature verification failed")]
|
||||
Signature(#[from] SignatureError),
|
||||
#[error("failed to read or write to the crypto store {0}")]
|
||||
Store(#[from] CryptoStoreError),
|
||||
#[error("decryption failed likely because a Olm session was wedged")]
|
||||
SessionWedged,
|
||||
#[error("the Olm message has a unsupported type")]
|
||||
UnsupportedOlmType,
|
||||
#[error("the Encrypted message has been encrypted with a unsupported algorithm.")]
|
||||
UnsupportedAlgorithm,
|
||||
#[error("the Encrypted message doesn't contain a ciphertext for our device")]
|
||||
MissingCiphertext,
|
||||
#[error("decryption failed because the session to decrypt the message is missing")]
|
||||
MissingSession,
|
||||
#[error("the Encrypted message is missing the signing key of the sender")]
|
||||
MissingSigningKey,
|
||||
/// The event that should have been decrypted is malformed.
|
||||
#[error(transparent)]
|
||||
EventError(#[from] EventError),
|
||||
|
||||
/// The received decrypted event couldn't be deserialized.
|
||||
#[error(transparent)]
|
||||
JsonError(#[from] SerdeError),
|
||||
|
||||
/// The underlying Olm session operation returned an error.
|
||||
#[error("can't finish Olm Session operation {0}")]
|
||||
OlmSession(#[from] OlmSessionError),
|
||||
|
||||
/// The underlying group session operation returned an error.
|
||||
#[error("can't finish Olm Session operation {0}")]
|
||||
OlmGroupSession(#[from] OlmGroupSessionError),
|
||||
#[error("error deserializing a string to json")]
|
||||
JsonError(#[from] SerdeError),
|
||||
#[error("the provided JSON value isn't an object")]
|
||||
NotAnObject,
|
||||
|
||||
/// The storage layer returned an error.
|
||||
#[error("failed to read or write to the crypto store {0}")]
|
||||
Store(#[from] CryptoStoreError),
|
||||
|
||||
/// The session with a device has become corrupted.
|
||||
#[error("decryption failed likely because a Olm session was wedged")]
|
||||
SessionWedged,
|
||||
}
|
||||
|
||||
pub type VerificationResult<T> = std::result::Result<T, SignatureError>;
|
||||
/// Error representing a failure during a group encryption operation.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum MegolmError {
|
||||
/// The event that should have been decrypted is malformed.
|
||||
#[error(transparent)]
|
||||
EventError(#[from] EventError),
|
||||
|
||||
/// The received decrypted event couldn't be deserialized.
|
||||
#[error(transparent)]
|
||||
JsonError(#[from] SerdeError),
|
||||
|
||||
/// Decryption failed because the session needed to decrypt the event is
|
||||
/// missing.
|
||||
#[error("decryption failed because the session to decrypt the message is missing")]
|
||||
MissingSession,
|
||||
|
||||
/// The underlying group session operation returned an error.
|
||||
#[error("can't finish Olm group session operation {0}")]
|
||||
OlmGroupSession(#[from] OlmGroupSessionError),
|
||||
|
||||
/// The storage layer returned an error.
|
||||
#[error(transparent)]
|
||||
Store(#[from] CryptoStoreError),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SignatureError {
|
||||
pub enum EventError {
|
||||
#[error("the Olm message has a unsupported type")]
|
||||
UnsupportedOlmType,
|
||||
|
||||
#[error("the Encrypted message has been encrypted with a unsupported algorithm.")]
|
||||
UnsupportedAlgorithm,
|
||||
|
||||
#[error("the provided JSON value isn't an object")]
|
||||
NotAnObject,
|
||||
|
||||
#[error("the Encrypted message doesn't contain a ciphertext for our device")]
|
||||
MissingCiphertext,
|
||||
|
||||
#[error("the Encrypted message is missing the signing key of the sender")]
|
||||
MissingSigningKey,
|
||||
|
||||
#[error("the Encrypted message is missing the field {0}")]
|
||||
MissingField(String),
|
||||
|
||||
#[error("the sender of the plaintext doesn't match the sender of the encrypted message.")]
|
||||
MissmatchedSender,
|
||||
|
||||
#[error("the keys of the message don't match the keys in our database.")]
|
||||
MissmatchedKeys,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub(crate) enum SignatureError {
|
||||
#[error("the provided JSON value isn't an object")]
|
||||
NotAnObject,
|
||||
|
||||
#[error("the provided JSON object doesn't contain a signatures field")]
|
||||
NoSignatureFound,
|
||||
|
||||
#[error("the provided JSON object can't be converted to a canonical representation")]
|
||||
CanonicalJsonError(CjsonError),
|
||||
|
||||
#[error("the signature didn't match the provided key")]
|
||||
VerificationError,
|
||||
}
|
||||
|
|
|
@ -15,6 +15,16 @@
|
|||
//! This is the encryption part of the matrix-sdk. It contains a state machine
|
||||
//! that will aid in adding encryption support to a client library.
|
||||
|
||||
#![deny(
|
||||
missing_debug_implementations,
|
||||
missing_docs,
|
||||
trivial_casts,
|
||||
trivial_numeric_casts,
|
||||
unused_extern_crates,
|
||||
unused_import_braces,
|
||||
unused_qualifications
|
||||
)]
|
||||
|
||||
mod device;
|
||||
mod error;
|
||||
mod machine;
|
||||
|
@ -23,6 +33,10 @@ mod olm;
|
|||
mod store;
|
||||
|
||||
pub use device::{Device, TrustState};
|
||||
pub use error::OlmError;
|
||||
pub use error::{MegolmError, OlmError};
|
||||
pub use machine::{OlmMachine, OneTimeKeys};
|
||||
pub use memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
|
||||
pub use olm::{Account, InboundGroupSession, OutboundGroupSession, Session};
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
pub use store::sqlite::SqliteStore;
|
||||
pub use store::{CryptoStore, CryptoStoreError};
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -23,7 +23,7 @@ use super::olm::{InboundGroupSession, Session};
|
|||
use matrix_sdk_types::identifiers::{DeviceId, RoomId, UserId};
|
||||
|
||||
/// In-memory store for Olm Sessions.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SessionStore {
|
||||
entries: HashMap<String, Arc<Mutex<Vec<Session>>>>,
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ impl SessionStore {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Default)]
|
||||
/// In-memory store that houlds inbound group sessions.
|
||||
pub struct GroupSessionStore {
|
||||
entries: HashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>,
|
||||
|
@ -127,12 +127,13 @@ impl GroupSessionStore {
|
|||
}
|
||||
|
||||
/// In-memory store holding the devices of users.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct DeviceStore {
|
||||
entries: Arc<DashMap<UserId, DashMap<String, Device>>>,
|
||||
}
|
||||
|
||||
/// A read only view over all devices belonging to a user.
|
||||
#[derive(Debug)]
|
||||
pub struct UserDevices {
|
||||
entries: ReadOnlyView<DeviceId, Device>,
|
||||
}
|
||||
|
@ -192,7 +193,7 @@ impl DeviceStore {
|
|||
self.entries
|
||||
.get(user_id)
|
||||
.and_then(|m| m.remove(device_id))
|
||||
.and_then(|(_, d)| Some(d))
|
||||
.map(|(_, d)| d)
|
||||
}
|
||||
|
||||
/// Get a read-only view over all devices of the given user.
|
||||
|
|
|
@ -38,9 +38,10 @@ pub use olm_rs::{
|
|||
use matrix_sdk_types::api::r0::keys::SignedKey;
|
||||
use matrix_sdk_types::identifiers::RoomId;
|
||||
|
||||
/// The Olm account.
|
||||
/// Account holding identity keys for which sessions can be created.
|
||||
///
|
||||
/// An account is the central identity for encrypted communication between two
|
||||
/// devices. It holds the two identity key pairs for a device.
|
||||
/// devices.
|
||||
#[derive(Clone)]
|
||||
pub struct Account {
|
||||
inner: Arc<Mutex<OlmAccount>>,
|
||||
|
@ -58,8 +59,15 @@ impl fmt::Debug for Account {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(tarpaulin, skip)]
|
||||
impl Default for Account {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Account {
|
||||
/// Create a new account.
|
||||
/// Create a fresh new account, this will generate the identity key-pair.
|
||||
pub fn new() -> Self {
|
||||
let account = OlmAccount::new();
|
||||
let identity_keys = account.parsed_identity_keys();
|
||||
|
@ -182,7 +190,7 @@ impl Account {
|
|||
inner: Arc::new(Mutex::new(session)),
|
||||
session_id: Arc::new(session_id),
|
||||
sender_key: Arc::new(their_identity_key.to_owned()),
|
||||
creation_time: Arc::new(now.clone()),
|
||||
creation_time: Arc::new(now),
|
||||
last_use_time: Arc::new(now),
|
||||
})
|
||||
}
|
||||
|
@ -223,7 +231,7 @@ impl Account {
|
|||
inner: Arc::new(Mutex::new(session)),
|
||||
session_id: Arc::new(session_id),
|
||||
sender_key: Arc::new(their_identity_key.to_owned()),
|
||||
creation_time: Arc::new(now.clone()),
|
||||
creation_time: Arc::new(now),
|
||||
last_use_time: Arc::new(now),
|
||||
})
|
||||
}
|
||||
|
@ -235,10 +243,8 @@ impl PartialEq for Account {
|
|||
}
|
||||
}
|
||||
|
||||
/// The Olm Session.
|
||||
///
|
||||
/// Sessions are used to exchange encrypted messages between two
|
||||
/// accounts/devices.
|
||||
/// Cryptographic session that enables secure communication between two
|
||||
/// `Account`s
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
inner: Arc<Mutex<OlmSession>>,
|
||||
|
@ -371,7 +377,7 @@ impl PartialEq for Session {
|
|||
|
||||
/// The private session key of a group session.
|
||||
/// Can be used to create a new inbound group session.
|
||||
#[derive(Clone, Serialize, Zeroize)]
|
||||
#[derive(Clone, Debug, Serialize, Zeroize)]
|
||||
#[zeroize(drop)]
|
||||
pub struct GroupSessionKey(pub String);
|
||||
|
||||
|
|
|
@ -52,8 +52,11 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_session(&mut self, session: Session) -> Result<()> {
|
||||
self.sessions.add(session).await;
|
||||
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
|
||||
for session in sessions {
|
||||
let _ = self.sessions.add(session.clone()).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -84,12 +87,13 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(self.tracked_users.insert(user.clone()))
|
||||
}
|
||||
|
||||
#[allow(clippy::ptr_arg)]
|
||||
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
|
||||
Ok(self.devices.get(user_id, device_id))
|
||||
}
|
||||
|
||||
async fn delete_device(&self, device: Device) -> Result<()> {
|
||||
self.devices.remove(device.user_id(), device.device_id());
|
||||
let _ = self.devices.remove(device.user_id(), device.device_id());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -97,8 +101,11 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(self.devices.user_devices(user_id))
|
||||
}
|
||||
|
||||
async fn save_device(&self, device: Device) -> Result<()> {
|
||||
self.devices.add(device);
|
||||
async fn save_devices(&self, devices: &[Device]) -> Result<()> {
|
||||
for device in devices {
|
||||
let _ = self.devices.add(device.clone());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -122,7 +129,7 @@ mod test {
|
|||
assert!(store.load_account().await.unwrap().is_none());
|
||||
store.save_account(account).await.unwrap();
|
||||
|
||||
store.save_session(session.clone()).await.unwrap();
|
||||
store.save_sessions(&[session.clone()]).await.unwrap();
|
||||
|
||||
let sessions = store
|
||||
.get_sessions(&session.sender_key)
|
||||
|
@ -150,7 +157,7 @@ mod test {
|
|||
.unwrap();
|
||||
|
||||
let mut store = MemoryStore::new();
|
||||
store
|
||||
let _ = store
|
||||
.save_inbound_group_session(inbound.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -168,7 +175,7 @@ mod test {
|
|||
let device = get_device();
|
||||
let store = MemoryStore::new();
|
||||
|
||||
store.save_device(device.clone()).await.unwrap();
|
||||
store.save_devices(&[device.clone()]).await.unwrap();
|
||||
|
||||
let loaded_device = store
|
||||
.get_device(device.user_id(), device.device_id())
|
||||
|
@ -205,6 +212,6 @@ mod test {
|
|||
|
||||
let tracked_users = store.tracked_users();
|
||||
|
||||
tracked_users.contains(device.user_id());
|
||||
let _ = tracked_users.contains(device.user_id());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,33 +37,54 @@ pub mod sqlite;
|
|||
use sqlx::Error as SqlxError;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
/// The crypto store's error type.
|
||||
pub enum CryptoStoreError {
|
||||
#[error("can't read or write from the store")]
|
||||
Io(#[from] IoError),
|
||||
#[error("can't finish Olm Account operation {0}")]
|
||||
OlmAccount(#[from] OlmAccountError),
|
||||
#[error("can't finish Olm Session operation {0}")]
|
||||
OlmSession(#[from] OlmSessionError),
|
||||
#[error("can't finish Olm GruoupSession operation {0}")]
|
||||
OlmGroupSession(#[from] OlmGroupSessionError),
|
||||
#[error("URL can't be parsed")]
|
||||
UrlParse(#[from] ParseError),
|
||||
#[error("error serializing data for the database")]
|
||||
Serialization(#[from] SerdeError),
|
||||
#[error("can't load session timestamps")]
|
||||
SessionTimestampError,
|
||||
#[error("can't save/load sessions or group sessions in the store before a account is stored")]
|
||||
/// The account that owns the sessions, group sessions, and devices wasn't
|
||||
/// found.
|
||||
#[error("can't save/load sessions or group sessions in the store before an account is stored")]
|
||||
AccountUnset,
|
||||
|
||||
/// SQL error occurred.
|
||||
// TODO flatten the SqlxError to make it easier for other store
|
||||
// implementations.
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
#[error("database error")]
|
||||
#[error(transparent)]
|
||||
DatabaseError(#[from] SqlxError),
|
||||
|
||||
/// An IO error occurred.
|
||||
#[error(transparent)]
|
||||
Io(#[from] IoError),
|
||||
|
||||
/// The underlying Olm Account operation returned an error.
|
||||
#[error(transparent)]
|
||||
OlmAccount(#[from] OlmAccountError),
|
||||
|
||||
/// The underlying Olm session operation returned an error.
|
||||
#[error(transparent)]
|
||||
OlmSession(#[from] OlmSessionError),
|
||||
|
||||
/// The underlying Olm group session operation returned an error.
|
||||
#[error(transparent)]
|
||||
OlmGroupSession(#[from] OlmGroupSessionError),
|
||||
|
||||
/// A session time-stamp couldn't be loaded.
|
||||
#[error("can't load session timestamps")]
|
||||
SessionTimestampError,
|
||||
|
||||
/// The store failed to (de)serialize a data type.
|
||||
#[error(transparent)]
|
||||
Serialization(#[from] SerdeError),
|
||||
|
||||
/// An error occurred while parsing an URL.
|
||||
#[error(transparent)]
|
||||
UrlParse(#[from] ParseError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||
|
||||
#[async_trait]
|
||||
/// Trait abstracting a store that the `OlmMachine` uses to store cryptographic
|
||||
/// keys.
|
||||
pub trait CryptoStore: Debug + Send + Sync {
|
||||
/// Load an account that was previously stored.
|
||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||
|
@ -75,12 +96,12 @@ pub trait CryptoStore: Debug + Send + Sync {
|
|||
/// * `account` - The account that should be stored.
|
||||
async fn save_account(&mut self, account: Account) -> Result<()>;
|
||||
|
||||
/// Save the given session in the store.
|
||||
/// Save the given sessions in the store.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session` - The session that should be stored.
|
||||
async fn save_session(&mut self, session: Session) -> Result<()>;
|
||||
/// * `session` - The sessions that should be stored.
|
||||
async fn save_sessions(&mut self, session: &[Session]) -> Result<()>;
|
||||
|
||||
/// Get all the sessions that belong to the given sender key.
|
||||
///
|
||||
|
@ -126,12 +147,12 @@ pub trait CryptoStore: Debug + Send + Sync {
|
|||
/// * `user` - The user that should be marked as tracked.
|
||||
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
|
||||
|
||||
/// Save the given device in the store.
|
||||
/// Save the given devices in the store.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The device that should be stored.
|
||||
async fn save_device(&self, device: Device) -> Result<()>;
|
||||
async fn save_devices(&self, devices: &[Device]) -> Result<()>;
|
||||
|
||||
/// Delete the given device from the store.
|
||||
///
|
||||
|
@ -147,6 +168,7 @@ pub trait CryptoStore: Debug + Send + Sync {
|
|||
/// * `user_id` - The user that the device belongs to.
|
||||
///
|
||||
/// * `device_id` - The unique id of the device.
|
||||
#[allow(clippy::ptr_arg)]
|
||||
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>>;
|
||||
|
||||
/// Get all the devices of the given user.
|
||||
|
|
|
@ -35,6 +35,7 @@ use matrix_sdk_types::api::r0::keys::KeyAlgorithm;
|
|||
use matrix_sdk_types::events::Algorithm;
|
||||
use matrix_sdk_types::identifiers::{DeviceId, RoomId, UserId};
|
||||
|
||||
/// SQLite based implementation of a `CryptoStore`.
|
||||
pub struct SqliteStore {
|
||||
user_id: Arc<String>,
|
||||
device_id: Arc<String>,
|
||||
|
@ -53,6 +54,17 @@ pub struct SqliteStore {
|
|||
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
|
||||
|
||||
impl SqliteStore {
|
||||
/// Open a new `SqliteStore`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `user_id` - The unique id of the user for which the store should be
|
||||
/// opened.
|
||||
///
|
||||
/// * `device_id` - The unique id of the device for which the store should
|
||||
/// be opened.
|
||||
///
|
||||
/// * `path` - The path where the database file should reside in.
|
||||
pub async fn open<P: AsRef<Path>>(
|
||||
user_id: &UserId,
|
||||
device_id: &str,
|
||||
|
@ -61,6 +73,20 @@ impl SqliteStore {
|
|||
SqliteStore::open_helper(user_id, device_id, path, None).await
|
||||
}
|
||||
|
||||
/// Open a new `SqliteStore`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `user_id` - The unique id of the user for which the store should be
|
||||
/// opened.
|
||||
///
|
||||
/// * `device_id` - The unique id of the device for which the store should
|
||||
/// be opened.
|
||||
///
|
||||
/// * `path` - The path where the database file should reside in.
|
||||
///
|
||||
/// * `passphrase` - The passphrase that should be used to securely store
|
||||
/// the encryption keys.
|
||||
pub async fn open_with_passphrase<P: AsRef<Path>>(
|
||||
user_id: &UserId,
|
||||
device_id: &str,
|
||||
|
@ -321,7 +347,8 @@ impl SqliteStore {
|
|||
|
||||
for row in rows {
|
||||
let device_row_id = row.0;
|
||||
let user_id = if let Ok(u) = UserId::try_from(&row.1 as &str) {
|
||||
let user_id: &str = &row.1;
|
||||
let user_id = if let Ok(u) = UserId::try_from(user_id) {
|
||||
u
|
||||
} else {
|
||||
continue;
|
||||
|
@ -339,7 +366,10 @@ impl SqliteStore {
|
|||
|
||||
let algorithms = algorithm_rows
|
||||
.iter()
|
||||
.map(|row| Algorithm::from(&row.0 as &str))
|
||||
.map(|row| {
|
||||
let algorithm: &str = &row.0;
|
||||
Algorithm::from(algorithm)
|
||||
})
|
||||
.collect::<Vec<Algorithm>>();
|
||||
|
||||
let key_rows: Vec<(String, String)> =
|
||||
|
@ -351,7 +381,8 @@ impl SqliteStore {
|
|||
let mut keys = BTreeMap::new();
|
||||
|
||||
for row in key_rows {
|
||||
let algorithm = if let Ok(a) = KeyAlgorithm::try_from(&row.0 as &str) {
|
||||
let algorithm: &str = &row.0;
|
||||
let algorithm = if let Ok(a) = KeyAlgorithm::try_from(algorithm) {
|
||||
a
|
||||
} else {
|
||||
continue;
|
||||
|
@ -480,12 +511,12 @@ impl CryptoStore for SqliteStore {
|
|||
|
||||
let mut group_sessions = self.load_inbound_group_sessions().await?;
|
||||
|
||||
let _ = group_sessions
|
||||
group_sessions
|
||||
.drain(..)
|
||||
.map(|s| {
|
||||
self.inbound_group_sessions.add(s);
|
||||
})
|
||||
.collect::<()>();
|
||||
.for_each(drop);
|
||||
|
||||
let devices = self.load_devices().await?;
|
||||
mem::replace(&mut self.devices, devices);
|
||||
|
@ -527,32 +558,35 @@ impl CryptoStore for SqliteStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_session(&mut self, session: Session) -> Result<()> {
|
||||
self.lazy_load_sessions(&session.sender_key).await?;
|
||||
self.sessions.add(session.clone()).await;
|
||||
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
|
||||
// TODO turn this into a transaction
|
||||
for session in sessions {
|
||||
self.lazy_load_sessions(&session.sender_key).await?;
|
||||
self.sessions.add(session.clone()).await;
|
||||
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
|
||||
let session_id = session.session_id();
|
||||
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
|
||||
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
|
||||
let pickle = session.pickle(self.get_pickle_mode()).await;
|
||||
let session_id = session.session_id();
|
||||
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
|
||||
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
|
||||
let pickle = session.pickle(self.get_pickle_mode()).await;
|
||||
|
||||
let mut connection = self.connection.lock().await;
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
query(
|
||||
"REPLACE INTO sessions (
|
||||
session_id, account_id, creation_time, last_use_time, sender_key, pickle
|
||||
) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
)
|
||||
.bind(&session_id)
|
||||
.bind(&account_id)
|
||||
.bind(&*creation_time)
|
||||
.bind(&*last_use_time)
|
||||
.bind(&*session.sender_key)
|
||||
.bind(&pickle)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
query(
|
||||
"REPLACE INTO sessions (
|
||||
session_id, account_id, creation_time, last_use_time, sender_key, pickle
|
||||
) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
)
|
||||
.bind(&session_id)
|
||||
.bind(&account_id)
|
||||
.bind(&*creation_time)
|
||||
.bind(&*last_use_time)
|
||||
.bind(&*session.sender_key)
|
||||
.bind(&pickle)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -608,15 +642,35 @@ impl CryptoStore for SqliteStore {
|
|||
Ok(self.tracked_users.insert(user.clone()))
|
||||
}
|
||||
|
||||
async fn save_device(&self, device: Device) -> Result<()> {
|
||||
self.devices.add(device.clone());
|
||||
self.save_device_helper(device).await
|
||||
async fn save_devices(&self, devices: &[Device]) -> Result<()> {
|
||||
// TODO turn this into a bulk transaction.
|
||||
for device in devices {
|
||||
self.devices.add(device.clone());
|
||||
self.save_device_helper(device.clone()).await?
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_device(&self, device: Device) -> Result<()> {
|
||||
todo!()
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
query(
|
||||
"DELETE FROM devices
|
||||
WHERE account_id = ?1 and user_id = ?2 and device_id = ?3
|
||||
",
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(&device.user_id().to_string())
|
||||
.bind(&device.device_id())
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::ptr_arg)]
|
||||
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
|
||||
Ok(self.devices.get(user_id, device_id))
|
||||
}
|
||||
|
@ -801,14 +855,14 @@ mod test {
|
|||
let (mut store, _dir) = get_store(None).await;
|
||||
let (account, session) = get_account_and_session().await;
|
||||
|
||||
assert!(store.save_session(session.clone()).await.is_err());
|
||||
assert!(store.save_sessions(&[session.clone()]).await.is_err());
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
|
||||
store.save_session(session).await.unwrap();
|
||||
store.save_sessions(&[session]).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -819,7 +873,7 @@ mod test {
|
|||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_session(session.clone()).await.unwrap();
|
||||
store.save_sessions(&[session.clone()]).await.unwrap();
|
||||
|
||||
let sessions = store
|
||||
.load_sessions_for(&session.sender_key)
|
||||
|
@ -841,7 +895,7 @@ mod test {
|
|||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_session(session).await.unwrap();
|
||||
store.save_sessions(&[session]).await.unwrap();
|
||||
|
||||
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
|
||||
let sessions_lock = sessions.lock().await;
|
||||
|
@ -937,7 +991,7 @@ mod test {
|
|||
let (_account, store, dir) = get_loaded_store().await;
|
||||
let device = get_device();
|
||||
|
||||
store.save_device(device.clone()).await.unwrap();
|
||||
store.save_devices(&[device.clone()]).await.unwrap();
|
||||
|
||||
drop(store);
|
||||
|
||||
|
@ -966,4 +1020,27 @@ mod test {
|
|||
assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id());
|
||||
assert_eq!(user_devices.devices().nth(0).unwrap(), &device);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn device_deleting() {
|
||||
let (_account, store, dir) = get_loaded_store().await;
|
||||
let device = get_device();
|
||||
|
||||
store.save_devices(&[device.clone()]).await.unwrap();
|
||||
store.delete_device(device.clone()).await.unwrap();
|
||||
|
||||
let mut store =
|
||||
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
|
||||
.await
|
||||
.expect("Can't create store");
|
||||
|
||||
store.load_account().await.unwrap();
|
||||
|
||||
let loaded_device = store
|
||||
.get_device(device.user_id(), device.device_id())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(loaded_device.is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ version = "0.1.0"
|
|||
|
||||
[dependencies]
|
||||
js_int = "0.1.5"
|
||||
ruma-api = "0.16.0-rc.2"
|
||||
ruma-client-api = { version = "0.8.0-rc.5" }
|
||||
ruma-events = { version = "0.21.0-beta.1" }
|
||||
ruma-identifiers = "0.16.0"
|
||||
ruma-api = "0.16.0-rc.3"
|
||||
ruma-client-api = "0.8.0-rc.5"
|
||||
ruma-events = "0.21.0"
|
||||
ruma-identifiers = "0.16.1"
|
||||
|
|
Loading…
Reference in a new issue