Merge branch 'master' of https://github.com/matrix-org/matrix-rust-sdk into messages

This commit is contained in:
Devin R 2020-05-05 06:57:37 -04:00
commit bfa9c0fda9
14 changed files with 653 additions and 321 deletions

View file

@ -23,39 +23,39 @@ http = "0.2.1"
url = "2.1.1"
async-trait = "0.1.30"
serde = "1.0.106"
serde_json = "1.0.51"
serde_json = "1.0.52"
uuid = { version = "0.8.1", features = ["v4"] }
matrix-sdk-types = { path = "../matrix_sdk_types" }
matrix-sdk-crypto = { path = "../matrix_sdk_crypto", optional = true }
# Misc dependencies
thiserror = "1.0.14"
thiserror = "1.0.16"
tracing = "0.1.13"
atomic = "0.4.5"
dashmap = "3.10.0"
dashmap = "3.11.1"
[dependencies.tracing-futures]
version = "0.2.3"
version = "0.2.4"
default-features = false
features = ["std", "std-future"]
[dependencies.tokio]
version = "0.2.16"
version = "0.2.20"
default-features = false
features = ["sync", "time", "fs"]
[dependencies.sqlx]
version = "0.3.3"
version = "0.3.4"
optional = true
default-features = false
features = ["runtime-tokio", "sqlite"]
[dev-dependencies]
tokio = { version = "0.2.16", features = ["rt-threaded", "macros"] }
ruma-identifiers = { version = "0.16.0", features = ["rand"] }
serde_json = "1.0.51"
tracing-subscriber = "0.2.4"
tokio = { version = "0.2.20", features = ["rt-threaded", "macros"] }
ruma-identifiers = { version = "0.16.1", features = ["rand"] }
serde_json = "1.0.52"
tracing-subscriber = "0.2.5"
tempfile = "3.1.0"
mockito = "0.25.1"
lazy_static = "1.4.0"

View file

@ -111,7 +111,7 @@ impl Client {
pub fn new(session: Option<Session>) -> Result<Self> {
#[cfg(feature = "encryption")]
let olm = match &session {
Some(s) => Some(OlmMachine::new(&s.user_id, &s.device_id)?),
Some(s) => Some(OlmMachine::new(&s.user_id, &s.device_id)),
None => None,
};
@ -199,7 +199,7 @@ impl Client {
#[cfg(feature = "encryption")]
{
let mut olm = self.olm.lock().await;
*olm = Some(OlmMachine::new(&response.user_id, &response.device_id)?);
*olm = Some(OlmMachine::new(&response.user_id, &response.device_id));
}
Ok(())
@ -261,12 +261,15 @@ impl Client {
/// Returns true if the room name changed, false otherwise.
pub(crate) fn handle_push_rules(&mut self, event: &PushRulesEvent) -> bool {
// TODO this is basically a stub
if self.push_ruleset.as_ref() == Some(&event.content.global) {
false
} else {
self.push_ruleset = Some(event.content.global.clone());
true
}
// TODO ruma removed PartialEq for evens, so this doesn't work anymore.
// Returning always true for now should be ok here since those don't
// change often.
// if self.push_ruleset.as_ref() == Some(&event.content.global) {
// false
// } else {
self.push_ruleset = Some(event.content.global.clone());
true
// }
}
/// Receive a timeline event for a joined room and update the client state.
@ -294,16 +297,13 @@ impl Client {
#[cfg(feature = "encryption")]
{
match e {
RoomEvent::RoomEncrypted(ref mut e) => {
e.room_id = Some(room_id.to_owned());
let mut olm = self.olm.lock().await;
if let RoomEvent::RoomEncrypted(ref mut e) = e {
e.room_id = Some(room_id.to_owned());
let mut olm = self.olm.lock().await;
if let Some(o) = &mut *olm {
decrypted_event = o.decrypt_room_event(&e).await.ok();
}
if let Some(o) = &mut *olm {
decrypted_event = o.decrypt_room_event(&e).await.ok();
}
_ => (),
}
}
@ -535,12 +535,15 @@ impl Client {
) -> Result<MessageEventContent> {
let mut olm = self.olm.lock().await;
match &mut *olm {
Some(o) => Ok(MessageEventContent::Encrypted(
o.encrypt(room_id, content).await?,
)),
None => panic!("Olm machine wasn't started"),
}
// TODO enable this again once we can send encrypted event
// contents with ruma.
// match &mut *olm {
// Some(o) => Ok(MessageEventContent::Encrypted(
// o.encrypt(room_id, content).await?,
// )),
// None => panic!("Olm machine wasn't started"),
// }
Ok(content)
}
/// Get a tuple of device and one-time keys that need to be uploaded.

View file

@ -26,7 +26,7 @@ use thiserror::Error;
use url::ParseError;
#[cfg(feature = "encryption")]
use matrix_sdk_crypto::OlmError;
use matrix_sdk_crypto::{MegolmError, OlmError};
/// Result type of the rust-sdk.
pub type Result<T> = std::result::Result<T, Error>;
@ -59,6 +59,9 @@ pub enum Error {
/// An error occurred during a E2EE operation.
#[error(transparent)]
OlmError(#[from] OlmError),
/// An error occurred during a E2EE group operation.
#[error(transparent)]
MegolmError(#[from] MegolmError),
}
impl From<RumaResponseError<RumaClientError>> for Error {

View file

@ -31,7 +31,7 @@ use crate::{Result, Room, Session};
/// When implementing `StateStore` for something other than the filesystem
/// implement `From<ClientState> for YourDbType` this allows for easy conversion
/// when needed in `StateStore::load/store_client_state`
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ClientState {
/// The current sync token that should be used for the next sync call.
pub sync_token: Option<Token>,
@ -41,6 +41,12 @@ pub struct ClientState {
pub push_ruleset: Option<Ruleset>,
}
impl PartialEq for ClientState {
fn eq(&self, other: &Self) -> bool {
self.sync_token == other.sync_token && self.ignored_users == other.ignored_users
}
}
impl ClientState {
pub fn from_base_client(client: &BaseClient) -> ClientState {
let BaseClient {

View file

@ -101,8 +101,8 @@ impl Device {
}
/// Get the key of the given key algorithm belonging to this device.
pub fn get_key(&self, algorithm: &KeyAlgorithm) -> Option<&String> {
self.keys.get(algorithm)
pub fn get_key(&self, algorithm: KeyAlgorithm) -> Option<&String> {
self.keys.get(&algorithm)
}
/// Get a map containing all the device keys.
@ -274,11 +274,11 @@ pub(crate) mod test {
device.display_name().as_ref().unwrap()
);
assert_eq!(
device.get_key(&KeyAlgorithm::Curve25519).unwrap(),
device.get_key(KeyAlgorithm::Curve25519).unwrap(),
"wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4"
);
assert_eq!(
device.get_key(&KeyAlgorithm::Ed25519).unwrap(),
device.get_key(KeyAlgorithm::Ed25519).unwrap(),
"nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM"
);
}

View file

@ -19,46 +19,101 @@ use thiserror::Error;
use super::store::CryptoStoreError;
pub type Result<T> = std::result::Result<T, OlmError>;
pub type OlmResult<T> = Result<T, OlmError>;
pub type MegolmResult<T> = Result<T, MegolmError>;
/// Error representing a failure during a device to device cryptographic
/// operation.
#[derive(Error, Debug)]
pub enum OlmError {
#[error("signature verification failed")]
Signature(#[from] SignatureError),
#[error("failed to read or write to the crypto store {0}")]
Store(#[from] CryptoStoreError),
#[error("decryption failed likely because a Olm session was wedged")]
SessionWedged,
#[error("the Olm message has a unsupported type")]
UnsupportedOlmType,
#[error("the Encrypted message has been encrypted with a unsupported algorithm.")]
UnsupportedAlgorithm,
#[error("the Encrypted message doesn't contain a ciphertext for our device")]
MissingCiphertext,
#[error("decryption failed because the session to decrypt the message is missing")]
MissingSession,
#[error("the Encrypted message is missing the signing key of the sender")]
MissingSigningKey,
/// The event that should have been decrypted is malformed.
#[error(transparent)]
EventError(#[from] EventError),
/// The received decrypted event couldn't be deserialized.
#[error(transparent)]
JsonError(#[from] SerdeError),
/// The underlying Olm session operation returned an error.
#[error("can't finish Olm Session operation {0}")]
OlmSession(#[from] OlmSessionError),
/// The underlying group session operation returned an error.
#[error("can't finish Olm Session operation {0}")]
OlmGroupSession(#[from] OlmGroupSessionError),
#[error("error deserializing a string to json")]
JsonError(#[from] SerdeError),
#[error("the provided JSON value isn't an object")]
NotAnObject,
/// The storage layer returned an error.
#[error("failed to read or write to the crypto store {0}")]
Store(#[from] CryptoStoreError),
/// The session with a device has become corrupted.
#[error("decryption failed likely because a Olm session was wedged")]
SessionWedged,
}
pub type VerificationResult<T> = std::result::Result<T, SignatureError>;
/// Error representing a failure during a group encryption operation.
#[derive(Error, Debug)]
pub enum MegolmError {
/// The event that should have been decrypted is malformed.
#[error(transparent)]
EventError(#[from] EventError),
/// The received decrypted event couldn't be deserialized.
#[error(transparent)]
JsonError(#[from] SerdeError),
/// Decryption failed because the session needed to decrypt the event is
/// missing.
#[error("decryption failed because the session to decrypt the message is missing")]
MissingSession,
/// The underlying group session operation returned an error.
#[error("can't finish Olm group session operation {0}")]
OlmGroupSession(#[from] OlmGroupSessionError),
/// The storage layer returned an error.
#[error(transparent)]
Store(#[from] CryptoStoreError),
}
#[derive(Error, Debug)]
pub enum SignatureError {
pub enum EventError {
#[error("the Olm message has a unsupported type")]
UnsupportedOlmType,
#[error("the Encrypted message has been encrypted with a unsupported algorithm.")]
UnsupportedAlgorithm,
#[error("the provided JSON value isn't an object")]
NotAnObject,
#[error("the Encrypted message doesn't contain a ciphertext for our device")]
MissingCiphertext,
#[error("the Encrypted message is missing the signing key of the sender")]
MissingSigningKey,
#[error("the Encrypted message is missing the field {0}")]
MissingField(String),
#[error("the sender of the plaintext doesn't match the sender of the encrypted message.")]
MissmatchedSender,
#[error("the keys of the message don't match the keys in our database.")]
MissmatchedKeys,
}
#[derive(Error, Debug)]
pub(crate) enum SignatureError {
#[error("the provided JSON value isn't an object")]
NotAnObject,
#[error("the provided JSON object doesn't contain a signatures field")]
NoSignatureFound,
#[error("the provided JSON object can't be converted to a canonical representation")]
CanonicalJsonError(CjsonError),
#[error("the signature didn't match the provided key")]
VerificationError,
}

View file

@ -15,6 +15,16 @@
//! This is the encryption part of the matrix-sdk. It contains a state machine
//! that will aid in adding encryption support to a client library.
#![deny(
missing_debug_implementations,
missing_docs,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
unused_qualifications
)]
mod device;
mod error;
mod machine;
@ -23,6 +33,10 @@ mod olm;
mod store;
pub use device::{Device, TrustState};
pub use error::OlmError;
pub use error::{MegolmError, OlmError};
pub use machine::{OlmMachine, OneTimeKeys};
pub use memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
pub use olm::{Account, InboundGroupSession, OutboundGroupSession, Session};
#[cfg(feature = "sqlite-cryptostore")]
pub use store::sqlite::SqliteStore;
pub use store::{CryptoStore, CryptoStoreError};

File diff suppressed because it is too large Load diff

View file

@ -23,7 +23,7 @@ use super::olm::{InboundGroupSession, Session};
use matrix_sdk_types::identifiers::{DeviceId, RoomId, UserId};
/// In-memory store for Olm Sessions.
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct SessionStore {
entries: HashMap<String, Arc<Mutex<Vec<Session>>>>,
}
@ -69,7 +69,7 @@ impl SessionStore {
}
}
#[derive(Debug)]
#[derive(Debug, Default)]
/// In-memory store that houlds inbound group sessions.
pub struct GroupSessionStore {
entries: HashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>,
@ -127,12 +127,13 @@ impl GroupSessionStore {
}
/// In-memory store holding the devices of users.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct DeviceStore {
entries: Arc<DashMap<UserId, DashMap<String, Device>>>,
}
/// A read only view over all devices belonging to a user.
#[derive(Debug)]
pub struct UserDevices {
entries: ReadOnlyView<DeviceId, Device>,
}
@ -192,7 +193,7 @@ impl DeviceStore {
self.entries
.get(user_id)
.and_then(|m| m.remove(device_id))
.and_then(|(_, d)| Some(d))
.map(|(_, d)| d)
}
/// Get a read-only view over all devices of the given user.

View file

@ -38,9 +38,10 @@ pub use olm_rs::{
use matrix_sdk_types::api::r0::keys::SignedKey;
use matrix_sdk_types::identifiers::RoomId;
/// The Olm account.
/// Account holding identity keys for which sessions can be created.
///
/// An account is the central identity for encrypted communication between two
/// devices. It holds the two identity key pairs for a device.
/// devices.
#[derive(Clone)]
pub struct Account {
inner: Arc<Mutex<OlmAccount>>,
@ -58,8 +59,15 @@ impl fmt::Debug for Account {
}
}
#[cfg_attr(tarpaulin, skip)]
impl Default for Account {
fn default() -> Self {
Self::new()
}
}
impl Account {
/// Create a new account.
/// Create a fresh new account, this will generate the identity key-pair.
pub fn new() -> Self {
let account = OlmAccount::new();
let identity_keys = account.parsed_identity_keys();
@ -182,7 +190,7 @@ impl Account {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(their_identity_key.to_owned()),
creation_time: Arc::new(now.clone()),
creation_time: Arc::new(now),
last_use_time: Arc::new(now),
})
}
@ -223,7 +231,7 @@ impl Account {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(their_identity_key.to_owned()),
creation_time: Arc::new(now.clone()),
creation_time: Arc::new(now),
last_use_time: Arc::new(now),
})
}
@ -235,10 +243,8 @@ impl PartialEq for Account {
}
}
/// The Olm Session.
///
/// Sessions are used to exchange encrypted messages between two
/// accounts/devices.
/// Cryptographic session that enables secure communication between two
/// `Account`s
#[derive(Clone)]
pub struct Session {
inner: Arc<Mutex<OlmSession>>,
@ -371,7 +377,7 @@ impl PartialEq for Session {
/// The private session key of a group session.
/// Can be used to create a new inbound group session.
#[derive(Clone, Serialize, Zeroize)]
#[derive(Clone, Debug, Serialize, Zeroize)]
#[zeroize(drop)]
pub struct GroupSessionKey(pub String);

View file

@ -52,8 +52,11 @@ impl CryptoStore for MemoryStore {
Ok(())
}
async fn save_session(&mut self, session: Session) -> Result<()> {
self.sessions.add(session).await;
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
for session in sessions {
let _ = self.sessions.add(session.clone()).await;
}
Ok(())
}
@ -84,12 +87,13 @@ impl CryptoStore for MemoryStore {
Ok(self.tracked_users.insert(user.clone()))
}
#[allow(clippy::ptr_arg)]
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
Ok(self.devices.get(user_id, device_id))
}
async fn delete_device(&self, device: Device) -> Result<()> {
self.devices.remove(device.user_id(), device.device_id());
let _ = self.devices.remove(device.user_id(), device.device_id());
Ok(())
}
@ -97,8 +101,11 @@ impl CryptoStore for MemoryStore {
Ok(self.devices.user_devices(user_id))
}
async fn save_device(&self, device: Device) -> Result<()> {
self.devices.add(device);
async fn save_devices(&self, devices: &[Device]) -> Result<()> {
for device in devices {
let _ = self.devices.add(device.clone());
}
Ok(())
}
}
@ -122,7 +129,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none());
store.save_account(account).await.unwrap();
store.save_session(session.clone()).await.unwrap();
store.save_sessions(&[session.clone()]).await.unwrap();
let sessions = store
.get_sessions(&session.sender_key)
@ -150,7 +157,7 @@ mod test {
.unwrap();
let mut store = MemoryStore::new();
store
let _ = store
.save_inbound_group_session(inbound.clone())
.await
.unwrap();
@ -168,7 +175,7 @@ mod test {
let device = get_device();
let store = MemoryStore::new();
store.save_device(device.clone()).await.unwrap();
store.save_devices(&[device.clone()]).await.unwrap();
let loaded_device = store
.get_device(device.user_id(), device.device_id())
@ -205,6 +212,6 @@ mod test {
let tracked_users = store.tracked_users();
tracked_users.contains(device.user_id());
let _ = tracked_users.contains(device.user_id());
}
}

View file

@ -37,33 +37,54 @@ pub mod sqlite;
use sqlx::Error as SqlxError;
#[derive(Error, Debug)]
/// The crypto store's error type.
pub enum CryptoStoreError {
#[error("can't read or write from the store")]
Io(#[from] IoError),
#[error("can't finish Olm Account operation {0}")]
OlmAccount(#[from] OlmAccountError),
#[error("can't finish Olm Session operation {0}")]
OlmSession(#[from] OlmSessionError),
#[error("can't finish Olm GruoupSession operation {0}")]
OlmGroupSession(#[from] OlmGroupSessionError),
#[error("URL can't be parsed")]
UrlParse(#[from] ParseError),
#[error("error serializing data for the database")]
Serialization(#[from] SerdeError),
#[error("can't load session timestamps")]
SessionTimestampError,
#[error("can't save/load sessions or group sessions in the store before a account is stored")]
/// The account that owns the sessions, group sessions, and devices wasn't
/// found.
#[error("can't save/load sessions or group sessions in the store before an account is stored")]
AccountUnset,
/// SQL error occurred.
// TODO flatten the SqlxError to make it easier for other store
// implementations.
#[cfg(feature = "sqlite-cryptostore")]
#[error("database error")]
#[error(transparent)]
DatabaseError(#[from] SqlxError),
/// An IO error occurred.
#[error(transparent)]
Io(#[from] IoError),
/// The underlying Olm Account operation returned an error.
#[error(transparent)]
OlmAccount(#[from] OlmAccountError),
/// The underlying Olm session operation returned an error.
#[error(transparent)]
OlmSession(#[from] OlmSessionError),
/// The underlying Olm group session operation returned an error.
#[error(transparent)]
OlmGroupSession(#[from] OlmGroupSessionError),
/// A session time-stamp couldn't be loaded.
#[error("can't load session timestamps")]
SessionTimestampError,
/// The store failed to (de)serialize a data type.
#[error(transparent)]
Serialization(#[from] SerdeError),
/// An error occurred while parsing an URL.
#[error(transparent)]
UrlParse(#[from] ParseError),
}
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
#[async_trait]
/// Trait abstracting a store that the `OlmMachine` uses to store cryptographic
/// keys.
pub trait CryptoStore: Debug + Send + Sync {
/// Load an account that was previously stored.
async fn load_account(&mut self) -> Result<Option<Account>>;
@ -75,12 +96,12 @@ pub trait CryptoStore: Debug + Send + Sync {
/// * `account` - The account that should be stored.
async fn save_account(&mut self, account: Account) -> Result<()>;
/// Save the given session in the store.
/// Save the given sessions in the store.
///
/// # Arguments
///
/// * `session` - The session that should be stored.
async fn save_session(&mut self, session: Session) -> Result<()>;
/// * `session` - The sessions that should be stored.
async fn save_sessions(&mut self, session: &[Session]) -> Result<()>;
/// Get all the sessions that belong to the given sender key.
///
@ -126,12 +147,12 @@ pub trait CryptoStore: Debug + Send + Sync {
/// * `user` - The user that should be marked as tracked.
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
/// Save the given device in the store.
/// Save the given devices in the store.
///
/// # Arguments
///
/// * `device` - The device that should be stored.
async fn save_device(&self, device: Device) -> Result<()>;
async fn save_devices(&self, devices: &[Device]) -> Result<()>;
/// Delete the given device from the store.
///
@ -147,6 +168,7 @@ pub trait CryptoStore: Debug + Send + Sync {
/// * `user_id` - The user that the device belongs to.
///
/// * `device_id` - The unique id of the device.
#[allow(clippy::ptr_arg)]
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>>;
/// Get all the devices of the given user.

View file

@ -35,6 +35,7 @@ use matrix_sdk_types::api::r0::keys::KeyAlgorithm;
use matrix_sdk_types::events::Algorithm;
use matrix_sdk_types::identifiers::{DeviceId, RoomId, UserId};
/// SQLite based implementation of a `CryptoStore`.
pub struct SqliteStore {
user_id: Arc<String>,
device_id: Arc<String>,
@ -53,6 +54,17 @@ pub struct SqliteStore {
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
impl SqliteStore {
/// Open a new `SqliteStore`.
///
/// # Arguments
///
/// * `user_id` - The unique id of the user for which the store should be
/// opened.
///
/// * `device_id` - The unique id of the device for which the store should
/// be opened.
///
/// * `path` - The path where the database file should reside in.
pub async fn open<P: AsRef<Path>>(
user_id: &UserId,
device_id: &str,
@ -61,6 +73,20 @@ impl SqliteStore {
SqliteStore::open_helper(user_id, device_id, path, None).await
}
/// Open a new `SqliteStore`.
///
/// # Arguments
///
/// * `user_id` - The unique id of the user for which the store should be
/// opened.
///
/// * `device_id` - The unique id of the device for which the store should
/// be opened.
///
/// * `path` - The path where the database file should reside in.
///
/// * `passphrase` - The passphrase that should be used to securely store
/// the encryption keys.
pub async fn open_with_passphrase<P: AsRef<Path>>(
user_id: &UserId,
device_id: &str,
@ -321,7 +347,8 @@ impl SqliteStore {
for row in rows {
let device_row_id = row.0;
let user_id = if let Ok(u) = UserId::try_from(&row.1 as &str) {
let user_id: &str = &row.1;
let user_id = if let Ok(u) = UserId::try_from(user_id) {
u
} else {
continue;
@ -339,7 +366,10 @@ impl SqliteStore {
let algorithms = algorithm_rows
.iter()
.map(|row| Algorithm::from(&row.0 as &str))
.map(|row| {
let algorithm: &str = &row.0;
Algorithm::from(algorithm)
})
.collect::<Vec<Algorithm>>();
let key_rows: Vec<(String, String)> =
@ -351,7 +381,8 @@ impl SqliteStore {
let mut keys = BTreeMap::new();
for row in key_rows {
let algorithm = if let Ok(a) = KeyAlgorithm::try_from(&row.0 as &str) {
let algorithm: &str = &row.0;
let algorithm = if let Ok(a) = KeyAlgorithm::try_from(algorithm) {
a
} else {
continue;
@ -480,12 +511,12 @@ impl CryptoStore for SqliteStore {
let mut group_sessions = self.load_inbound_group_sessions().await?;
let _ = group_sessions
group_sessions
.drain(..)
.map(|s| {
self.inbound_group_sessions.add(s);
})
.collect::<()>();
.for_each(drop);
let devices = self.load_devices().await?;
mem::replace(&mut self.devices, devices);
@ -527,32 +558,35 @@ impl CryptoStore for SqliteStore {
Ok(())
}
async fn save_session(&mut self, session: Session) -> Result<()> {
self.lazy_load_sessions(&session.sender_key).await?;
self.sessions.add(session.clone()).await;
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
// TODO turn this into a transaction
for session in sessions {
self.lazy_load_sessions(&session.sender_key).await?;
self.sessions.add(session.clone()).await;
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let session_id = session.session_id();
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
let pickle = session.pickle(self.get_pickle_mode()).await;
let session_id = session.session_id();
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await;
let mut connection = self.connection.lock().await;
query(
"REPLACE INTO sessions (
session_id, account_id, creation_time, last_use_time, sender_key, pickle
) VALUES (?, ?, ?, ?, ?, ?)",
)
.bind(&session_id)
.bind(&account_id)
.bind(&*creation_time)
.bind(&*last_use_time)
.bind(&*session.sender_key)
.bind(&pickle)
.execute(&mut *connection)
.await?;
query(
"REPLACE INTO sessions (
session_id, account_id, creation_time, last_use_time, sender_key, pickle
) VALUES (?, ?, ?, ?, ?, ?)",
)
.bind(&session_id)
.bind(&account_id)
.bind(&*creation_time)
.bind(&*last_use_time)
.bind(&*session.sender_key)
.bind(&pickle)
.execute(&mut *connection)
.await?;
}
Ok(())
}
@ -608,15 +642,35 @@ impl CryptoStore for SqliteStore {
Ok(self.tracked_users.insert(user.clone()))
}
async fn save_device(&self, device: Device) -> Result<()> {
self.devices.add(device.clone());
self.save_device_helper(device).await
async fn save_devices(&self, devices: &[Device]) -> Result<()> {
// TODO turn this into a bulk transaction.
for device in devices {
self.devices.add(device.clone());
self.save_device_helper(device.clone()).await?
}
Ok(())
}
async fn delete_device(&self, device: Device) -> Result<()> {
todo!()
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
query(
"DELETE FROM devices
WHERE account_id = ?1 and user_id = ?2 and device_id = ?3
",
)
.bind(account_id)
.bind(&device.user_id().to_string())
.bind(&device.device_id())
.execute(&mut *connection)
.await?;
Ok(())
}
#[allow(clippy::ptr_arg)]
async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
Ok(self.devices.get(user_id, device_id))
}
@ -801,14 +855,14 @@ mod test {
let (mut store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await;
assert!(store.save_session(session.clone()).await.is_err());
assert!(store.save_sessions(&[session.clone()]).await.is_err());
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_session(session).await.unwrap();
store.save_sessions(&[session]).await.unwrap();
}
#[tokio::test]
@ -819,7 +873,7 @@ mod test {
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_session(session.clone()).await.unwrap();
store.save_sessions(&[session.clone()]).await.unwrap();
let sessions = store
.load_sessions_for(&session.sender_key)
@ -841,7 +895,7 @@ mod test {
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_session(session).await.unwrap();
store.save_sessions(&[session]).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
let sessions_lock = sessions.lock().await;
@ -937,7 +991,7 @@ mod test {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
store.save_device(device.clone()).await.unwrap();
store.save_devices(&[device.clone()]).await.unwrap();
drop(store);
@ -966,4 +1020,27 @@ mod test {
assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id());
assert_eq!(user_devices.devices().nth(0).unwrap(), &device);
}
#[tokio::test]
async fn device_deleting() {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
store.save_devices(&[device.clone()]).await.unwrap();
store.delete_device(device.clone()).await.unwrap();
let mut store =
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
.await
.expect("Can't create store");
store.load_account().await.unwrap();
let loaded_device = store
.get_device(device.user_id(), device.device_id())
.await
.unwrap();
assert!(loaded_device.is_none());
}
}

View file

@ -12,7 +12,7 @@ version = "0.1.0"
[dependencies]
js_int = "0.1.5"
ruma-api = "0.16.0-rc.2"
ruma-client-api = { version = "0.8.0-rc.5" }
ruma-events = { version = "0.21.0-beta.1" }
ruma-identifiers = "0.16.0"
ruma-api = "0.16.0-rc.3"
ruma-client-api = "0.8.0-rc.5"
ruma-events = "0.21.0"
ruma-identifiers = "0.16.1"