Merge branch 'lockless-cryptostore' into master
commit
42a4ad60e8
|
@ -8,7 +8,3 @@ members = [
|
||||||
"matrix_sdk_common",
|
"matrix_sdk_common",
|
||||||
"matrix_sdk_common_macros",
|
"matrix_sdk_common_macros",
|
||||||
]
|
]
|
||||||
|
|
||||||
[patch.crates-io]
|
|
||||||
olm-rs = { git = 'https://gitlab.gnome.org/jhaye/olm-rs/'}
|
|
||||||
olm-sys = { git = 'https://gitlab.gnome.org/BrainBlasted/olm-sys' }
|
|
||||||
|
|
|
@ -21,9 +21,9 @@ async-trait = "0.1.36"
|
||||||
http = "0.2.1"
|
http = "0.2.1"
|
||||||
# FIXME: Revert to regular dependency once 0.10.8 or 0.11.0 is released
|
# FIXME: Revert to regular dependency once 0.10.8 or 0.11.0 is released
|
||||||
reqwest = { git = "https://github.com/seanmonstar/reqwest", rev = "dd8441fd23dae6ffb79b4cea2862e5bca0c59743" }
|
reqwest = { git = "https://github.com/seanmonstar/reqwest", rev = "dd8441fd23dae6ffb79b4cea2862e5bca0c59743" }
|
||||||
serde_json = "1.0.56"
|
serde_json = "1.0.57"
|
||||||
thiserror = "1.0.20"
|
thiserror = "1.0.20"
|
||||||
tracing = "0.1.16"
|
tracing = "0.1.19"
|
||||||
url = "2.1.1"
|
url = "2.1.1"
|
||||||
|
|
||||||
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
|
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
|
||||||
|
@ -50,11 +50,11 @@ features = ["wasm-bindgen"]
|
||||||
async-trait = "0.1.36"
|
async-trait = "0.1.36"
|
||||||
dirs = "3.0.1"
|
dirs = "3.0.1"
|
||||||
matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" }
|
matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" }
|
||||||
tokio = { version = "0.2.21", features = ["rt-threaded", "macros"] }
|
tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] }
|
||||||
serde_json = "1.0.56"
|
serde_json = "1.0.57"
|
||||||
tracing-subscriber = "0.2.7"
|
tracing-subscriber = "0.2.11"
|
||||||
tempfile = "3.1.0"
|
tempfile = "3.1.0"
|
||||||
mockito = "0.26.0"
|
mockito = "0.27.0"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
futures = "0.3.5"
|
futures = "0.3.5"
|
||||||
|
|
||||||
|
|
|
@ -73,7 +73,7 @@ pub struct Client {
|
||||||
pub(crate) base_client: BaseClient,
|
pub(crate) base_client: BaseClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl Debug for Client {
|
impl Debug for Client {
|
||||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> {
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> {
|
||||||
write!(fmt, "Client {{ homeserver: {} }}", self.homeserver)
|
write!(fmt, "Client {{ homeserver: {} }}", self.homeserver)
|
||||||
|
@ -115,7 +115,7 @@ pub struct ClientConfig {
|
||||||
pub(crate) client: Option<Arc<dyn HttpSend>>,
|
pub(crate) client: Option<Arc<dyn HttpSend>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl Debug for ClientConfig {
|
impl Debug for ClientConfig {
|
||||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
let mut res = fmt.debug_struct("ClientConfig");
|
let mut res = fmt.debug_struct("ClientConfig");
|
||||||
|
|
|
@ -18,10 +18,10 @@ sqlite-cryptostore = ["matrix-sdk-crypto/sqlite-cryptostore"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1.36"
|
async-trait = "0.1.36"
|
||||||
serde = "1.0.114"
|
serde = "1.0.115"
|
||||||
serde_json = "1.0.56"
|
serde_json = "1.0.57"
|
||||||
zeroize = "1.1.0"
|
zeroize = "1.1.0"
|
||||||
tracing = "0.1.16"
|
tracing = "0.1.19"
|
||||||
|
|
||||||
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
|
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
|
||||||
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
|
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
|
||||||
|
@ -31,19 +31,19 @@ matrix-sdk-crypto = { version = "0.1.0", path = "../matrix_sdk_crypto", optional
|
||||||
thiserror = "1.0.20"
|
thiserror = "1.0.20"
|
||||||
|
|
||||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio]
|
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio]
|
||||||
version = "0.2.21"
|
version = "0.2.22"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = ["sync", "fs"]
|
features = ["sync", "fs"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" }
|
matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" }
|
||||||
http = "0.2.1"
|
http = "0.2.1"
|
||||||
tracing-subscriber = "0.2.7"
|
tracing-subscriber = "0.2.11"
|
||||||
tempfile = "3.1.0"
|
tempfile = "3.1.0"
|
||||||
mockito = "0.26.0"
|
mockito = "0.27.0"
|
||||||
|
|
||||||
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
|
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
|
||||||
tokio = { version = "0.2.21", features = ["rt-threaded", "macros"] }
|
tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] }
|
||||||
|
|
||||||
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
|
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
|
||||||
wasm-bindgen-test = "0.3.15"
|
wasm-bindgen-test = "0.3.17"
|
||||||
|
|
|
@ -212,7 +212,7 @@ pub struct BaseClient {
|
||||||
store_passphrase: Arc<Zeroizing<String>>,
|
store_passphrase: Arc<Zeroizing<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl fmt::Debug for BaseClient {
|
impl fmt::Debug for BaseClient {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("Client")
|
f.debug_struct("Client")
|
||||||
|
@ -246,7 +246,7 @@ pub struct BaseClientConfig {
|
||||||
passphrase: Option<Zeroizing<String>>,
|
passphrase: Option<Zeroizing<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl std::fmt::Debug for BaseClientConfig {
|
impl std::fmt::Debug for BaseClientConfig {
|
||||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> {
|
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> {
|
||||||
fmt.debug_struct("BaseClientConfig").finish()
|
fmt.debug_struct("BaseClientConfig").finish()
|
||||||
|
|
|
@ -12,7 +12,7 @@ version = "0.1.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
instant = { version = "0.1.6", features = ["wasm-bindgen", "now"] }
|
instant = { version = "0.1.6", features = ["wasm-bindgen", "now"] }
|
||||||
js_int = "0.1.8"
|
js_int = "0.1.9"
|
||||||
|
|
||||||
[dependencies.ruma]
|
[dependencies.ruma]
|
||||||
version = "0.0.1"
|
version = "0.0.1"
|
||||||
|
@ -24,7 +24,7 @@ features = ["client-api"]
|
||||||
uuid = { version = "0.8.1", features = ["v4"] }
|
uuid = { version = "0.8.1", features = ["v4"] }
|
||||||
|
|
||||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio]
|
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio]
|
||||||
version = "0.2.21"
|
version = "0.2.22"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = ["sync", "time", "fs"]
|
features = ["sync", "time", "fs"]
|
||||||
|
|
||||||
|
|
|
@ -20,18 +20,18 @@ async-trait = "0.1.36"
|
||||||
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
|
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
|
||||||
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
|
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
|
||||||
|
|
||||||
olm-rs = { version = "0.5.0", features = ["serde"] }
|
olm-rs = { git = 'https://gitlab.gnome.org/jhaye/olm-rs/', features = ["serde"]}
|
||||||
serde = { version = "1.0.114", features = ["derive"] }
|
serde = { version = "1.0.115", features = ["derive"] }
|
||||||
serde_json = "1.0.56"
|
serde_json = "1.0.57"
|
||||||
cjson = "0.1.1"
|
cjson = "0.1.1"
|
||||||
zeroize = { version = "1.1.0", features = ["zeroize_derive"] }
|
zeroize = { version = "1.1.0", features = ["zeroize_derive"] }
|
||||||
url = "2.1.1"
|
url = "2.1.1"
|
||||||
|
|
||||||
# Misc dependencies
|
# Misc dependencies
|
||||||
thiserror = "1.0.20"
|
thiserror = "1.0.20"
|
||||||
tracing = "0.1.16"
|
tracing = "0.1.19"
|
||||||
atomic = "0.4.6"
|
atomic = "0.5.0"
|
||||||
dashmap = "3.11.7"
|
dashmap = "3.11.10"
|
||||||
|
|
||||||
[dependencies.tracing-futures]
|
[dependencies.tracing-futures]
|
||||||
version = "0.2.4"
|
version = "0.2.4"
|
||||||
|
|
|
@ -198,7 +198,7 @@ impl Device {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub async fn from_machine(machine: &OlmMachine) -> Device {
|
pub async fn from_machine(machine: &OlmMachine) -> Device {
|
||||||
Device::from_account(&machine.account).await
|
Device::from_account(machine.account()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -41,7 +41,6 @@ use matrix_sdk_common::{
|
||||||
Algorithm, AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent,
|
Algorithm, AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent,
|
||||||
},
|
},
|
||||||
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
|
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
uuid::Uuid,
|
uuid::Uuid,
|
||||||
Raw,
|
Raw,
|
||||||
};
|
};
|
||||||
|
@ -76,11 +75,11 @@ pub struct OlmMachine {
|
||||||
/// The unique device id of the device that holds this account.
|
/// The unique device id of the device that holds this account.
|
||||||
device_id: Box<DeviceId>,
|
device_id: Box<DeviceId>,
|
||||||
/// Our underlying Olm Account holding our identity keys.
|
/// Our underlying Olm Account holding our identity keys.
|
||||||
pub(crate) account: Account,
|
account: Account,
|
||||||
/// Store for the encryption keys.
|
/// Store for the encryption keys.
|
||||||
/// Persists all the encryption keys so a client can resume the session
|
/// Persists all the encryption keys so a client can resume the session
|
||||||
/// without the need to create new keys.
|
/// without the need to create new keys.
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
/// The currently active outbound group sessions.
|
/// The currently active outbound group sessions.
|
||||||
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
||||||
/// A state machine that is responsible to handle and keep track of SAS
|
/// A state machine that is responsible to handle and keep track of SAS
|
||||||
|
@ -88,7 +87,7 @@ pub struct OlmMachine {
|
||||||
verification_machine: VerificationMachine,
|
verification_machine: VerificationMachine,
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl std::fmt::Debug for OlmMachine {
|
impl std::fmt::Debug for OlmMachine {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
f.debug_struct("OlmMachine")
|
f.debug_struct("OlmMachine")
|
||||||
|
@ -111,10 +110,9 @@ impl OlmMachine {
|
||||||
/// * `user_id` - The unique id of the user that owns this machine.
|
/// * `user_id` - The unique id of the user that owns this machine.
|
||||||
///
|
///
|
||||||
/// * `device_id` - The unique id of the device that owns this machine.
|
/// * `device_id` - The unique id of the device that owns this machine.
|
||||||
#[allow(clippy::ptr_arg)]
|
|
||||||
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
|
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
|
||||||
let store: Box<dyn CryptoStore> = Box::new(MemoryStore::new());
|
let store: Box<dyn CryptoStore> = Box::new(MemoryStore::new());
|
||||||
let store = Arc::new(RwLock::new(store));
|
let store = Arc::new(store);
|
||||||
let account = Account::new(user_id, device_id);
|
let account = Account::new(user_id, device_id);
|
||||||
|
|
||||||
OlmMachine {
|
OlmMachine {
|
||||||
|
@ -147,7 +145,7 @@ impl OlmMachine {
|
||||||
pub async fn new_with_store(
|
pub async fn new_with_store(
|
||||||
user_id: UserId,
|
user_id: UserId,
|
||||||
device_id: Box<DeviceId>,
|
device_id: Box<DeviceId>,
|
||||||
mut store: Box<dyn CryptoStore>,
|
store: Box<dyn CryptoStore>,
|
||||||
) -> StoreResult<Self> {
|
) -> StoreResult<Self> {
|
||||||
let account = match store.load_account().await? {
|
let account = match store.load_account().await? {
|
||||||
Some(a) => {
|
Some(a) => {
|
||||||
|
@ -160,7 +158,7 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let store = Arc::new(RwLock::new(store));
|
let store = Arc::new(store);
|
||||||
let verification_machine = VerificationMachine::new(account.clone(), store.clone());
|
let verification_machine = VerificationMachine::new(account.clone(), store.clone());
|
||||||
|
|
||||||
Ok(OlmMachine {
|
Ok(OlmMachine {
|
||||||
|
@ -216,6 +214,12 @@ impl OlmMachine {
|
||||||
self.account.should_upload_keys().await
|
self.account.should_upload_keys().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the underlying Olm account of the machine.
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) fn account(&self) -> &Account {
|
||||||
|
&self.account
|
||||||
|
}
|
||||||
|
|
||||||
/// Update the count of one-time keys that are currently on the server.
|
/// Update the count of one-time keys that are currently on the server.
|
||||||
fn update_key_count(&self, count: u64) {
|
fn update_key_count(&self, count: u64) {
|
||||||
self.account.update_uploaded_key_count(count);
|
self.account.update_uploaded_key_count(count);
|
||||||
|
@ -250,11 +254,7 @@ impl OlmMachine {
|
||||||
self.update_key_count(count);
|
self.update_key_count(count);
|
||||||
|
|
||||||
self.account.mark_keys_as_published().await;
|
self.account.mark_keys_as_published().await;
|
||||||
self.store
|
self.store.save_account(self.account.clone()).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_account(self.account.clone())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -285,7 +285,7 @@ impl OlmMachine {
|
||||||
let mut missing = BTreeMap::new();
|
let mut missing = BTreeMap::new();
|
||||||
|
|
||||||
for user_id in users {
|
for user_id in users {
|
||||||
let user_devices = self.store.read().await.get_user_devices(user_id).await?;
|
let user_devices = self.store.get_user_devices(user_id).await?;
|
||||||
|
|
||||||
for device in user_devices.devices() {
|
for device in user_devices.devices() {
|
||||||
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
|
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
|
||||||
|
@ -294,7 +294,7 @@ impl OlmMachine {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
let sessions = self.store.write().await.get_sessions(sender_key).await?;
|
let sessions = self.store.get_sessions(sender_key).await?;
|
||||||
|
|
||||||
let is_missing = if let Some(sessions) = sessions {
|
let is_missing = if let Some(sessions) = sessions {
|
||||||
sessions.lock().await.is_empty()
|
sessions.lock().await.is_empty()
|
||||||
|
@ -333,13 +333,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
for (user_id, user_devices) in &response.one_time_keys {
|
for (user_id, user_devices) in &response.one_time_keys {
|
||||||
for (device_id, key_map) in user_devices {
|
for (device_id, key_map) in user_devices {
|
||||||
let device: Device = match self
|
let device: Device = match self.store.get_device(&user_id, device_id).await {
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&user_id, device_id)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(Some(d)) => d,
|
Ok(Some(d)) => d,
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
warn!(
|
warn!(
|
||||||
|
@ -368,7 +362,7 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(e) = self.store.write().await.save_sessions(&[session]).await {
|
if let Err(e) = self.store.save_sessions(&[session]).await {
|
||||||
error!("Failed to store newly created Olm session {}", e);
|
error!("Failed to store newly created Olm session {}", e);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -389,11 +383,7 @@ impl OlmMachine {
|
||||||
let mut changed_devices = Vec::new();
|
let mut changed_devices = Vec::new();
|
||||||
|
|
||||||
for (user_id, device_map) in device_keys_map {
|
for (user_id, device_map) in device_keys_map {
|
||||||
self.store
|
self.store.update_tracked_user(user_id, false).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.update_tracked_user(user_id, false)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
for (device_id, device_keys) in device_map.iter() {
|
for (device_id, device_keys) in device_map.iter() {
|
||||||
// We don't need our own device in the device store.
|
// We don't need our own device in the device store.
|
||||||
|
@ -409,12 +399,7 @@ impl OlmMachine {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let device = self
|
let device = self.store.get_device(&user_id, device_id).await?;
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&user_id, device_id)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let device = if let Some(mut device) = device {
|
let device = if let Some(mut device) = device {
|
||||||
if let Err(e) = device.update_device(device_keys) {
|
if let Err(e) = device.update_device(device_keys) {
|
||||||
|
@ -445,13 +430,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
let current_devices: HashSet<&DeviceId> =
|
let current_devices: HashSet<&DeviceId> =
|
||||||
device_map.keys().map(|id| id.as_ref()).collect();
|
device_map.keys().map(|id| id.as_ref()).collect();
|
||||||
let stored_devices = self
|
let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_user_devices(&user_id)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
|
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
|
||||||
|
|
||||||
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
||||||
|
@ -459,7 +438,7 @@ impl OlmMachine {
|
||||||
for device_id in deleted_devices {
|
for device_id in deleted_devices {
|
||||||
if let Some(device) = stored_devices.get(device_id) {
|
if let Some(device) = stored_devices.get(device_id) {
|
||||||
device.mark_as_deleted();
|
device.mark_as_deleted();
|
||||||
self.store.write().await.delete_device(device).await?;
|
self.store.delete_device(device).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -483,11 +462,7 @@ impl OlmMachine {
|
||||||
let changed_devices = self
|
let changed_devices = self
|
||||||
.handle_devices_from_key_query(&response.device_keys)
|
.handle_devices_from_key_query(&response.device_keys)
|
||||||
.await?;
|
.await?;
|
||||||
self.store
|
self.store.save_devices(&changed_devices).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_devices(&changed_devices)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(changed_devices)
|
Ok(changed_devices)
|
||||||
}
|
}
|
||||||
|
@ -511,7 +486,7 @@ impl OlmMachine {
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
message: &OlmMessage,
|
message: &OlmMessage,
|
||||||
) -> OlmResult<Option<String>> {
|
) -> OlmResult<Option<String>> {
|
||||||
let s = self.store.write().await.get_sessions(sender_key).await?;
|
let s = self.store.get_sessions(sender_key).await?;
|
||||||
|
|
||||||
// We don't have any existing sessions, return early.
|
// We don't have any existing sessions, return early.
|
||||||
let sessions = if let Some(s) = s {
|
let sessions = if let Some(s) = s {
|
||||||
|
@ -561,7 +536,7 @@ impl OlmMachine {
|
||||||
// Decryption was successful, save the new ratchet state of the
|
// Decryption was successful, save the new ratchet state of the
|
||||||
// session that was used to decrypt the message.
|
// session that was used to decrypt the message.
|
||||||
trace!("Saved the new session state for {}", sender);
|
trace!("Saved the new session state for {}", sender);
|
||||||
self.store.write().await.save_sessions(&[session]).await?;
|
self.store.save_sessions(&[session]).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(plaintext)
|
Ok(plaintext)
|
||||||
|
@ -616,11 +591,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
// Save the account since we remove the one-time key that
|
// Save the account since we remove the one-time key that
|
||||||
// was used to create this session.
|
// was used to create this session.
|
||||||
self.store
|
self.store.save_account(self.account.clone()).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_account(self.account.clone())
|
|
||||||
.await?;
|
|
||||||
session
|
session
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -630,7 +601,7 @@ impl OlmMachine {
|
||||||
let plaintext = session.decrypt(message).await?;
|
let plaintext = session.decrypt(message).await?;
|
||||||
|
|
||||||
// Save the new ratcheted state of the session.
|
// Save the new ratcheted state of the session.
|
||||||
self.store.write().await.save_sessions(&[session]).await?;
|
self.store.save_sessions(&[session]).await?;
|
||||||
plaintext
|
plaintext
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -781,12 +752,7 @@ impl OlmMachine {
|
||||||
&event.content.room_id,
|
&event.content.room_id,
|
||||||
session_key,
|
session_key,
|
||||||
)?;
|
)?;
|
||||||
let _ = self
|
let _ = self.store.save_inbound_group_session(session).await?;
|
||||||
.store
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_inbound_group_session(session)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
|
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
|
||||||
Ok(Some(event))
|
Ok(Some(event))
|
||||||
|
@ -808,12 +774,7 @@ impl OlmMachine {
|
||||||
async fn create_outbound_group_session(&self, room_id: &RoomId) -> OlmResult<()> {
|
async fn create_outbound_group_session(&self, room_id: &RoomId) -> OlmResult<()> {
|
||||||
let (outbound, inbound) = self.account.create_group_session_pair(room_id).await;
|
let (outbound, inbound) = self.account.create_group_session_pair(room_id).await;
|
||||||
|
|
||||||
let _ = self
|
let _ = self.store.save_inbound_group_session(inbound).await?;
|
||||||
.store
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_inbound_group_session(inbound)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let _ = self
|
let _ = self
|
||||||
.outbound_group_sessions
|
.outbound_group_sessions
|
||||||
|
@ -899,8 +860,7 @@ impl OlmMachine {
|
||||||
return Err(EventError::MissingSenderKey.into());
|
return Err(EventError::MissingSenderKey.into());
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut session = if let Some(s) = self.store.write().await.get_sessions(sender_key).await?
|
let mut session = if let Some(s) = self.store.get_sessions(sender_key).await? {
|
||||||
{
|
|
||||||
let session = &s.lock().await[0];
|
let session = &s.lock().await[0];
|
||||||
session.clone()
|
session.clone()
|
||||||
} else {
|
} else {
|
||||||
|
@ -914,7 +874,7 @@ impl OlmMachine {
|
||||||
};
|
};
|
||||||
|
|
||||||
let message = session.encrypt(recipient_device, event_type, content).await;
|
let message = session.encrypt(recipient_device, event_type, content).await;
|
||||||
self.store.write().await.save_sessions(&[session]).await?;
|
self.store.save_sessions(&[session]).await?;
|
||||||
|
|
||||||
message
|
message
|
||||||
}
|
}
|
||||||
|
@ -969,7 +929,7 @@ impl OlmMachine {
|
||||||
panic!("Session is already shared");
|
panic!("Session is already shared");
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO don't mark the session as shared automatically only, when all
|
// TODO don't mark the session as shared automatically, only when all
|
||||||
// the requests are done, failure to send these requests will likely end
|
// the requests are done, failure to send these requests will likely end
|
||||||
// up in wedged sessions. We'll need to store the requests and let the
|
// up in wedged sessions. We'll need to store the requests and let the
|
||||||
// caller mark them as sent using an UUID.
|
// caller mark them as sent using an UUID.
|
||||||
|
@ -978,15 +938,7 @@ impl OlmMachine {
|
||||||
let mut devices = Vec::new();
|
let mut devices = Vec::new();
|
||||||
|
|
||||||
for user_id in users {
|
for user_id in users {
|
||||||
for device in self
|
for device in self.store.get_user_devices(user_id).await?.devices() {
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_user_devices(user_id)
|
|
||||||
.await?
|
|
||||||
.devices()
|
|
||||||
{
|
|
||||||
// TODO abort if the device isn't verified
|
|
||||||
devices.push(device.clone());
|
devices.push(device.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1193,8 +1145,6 @@ impl OlmMachine {
|
||||||
|
|
||||||
let session = self
|
let session = self
|
||||||
.store
|
.store
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.get_inbound_group_session(room_id, &content.sender_key, &content.session_id)
|
.get_inbound_group_session(room_id, &content.sender_key, &content.session_id)
|
||||||
.await?;
|
.await?;
|
||||||
// TODO check if the Olm session is wedged and re-request the key.
|
// TODO check if the Olm session is wedged and re-request the key.
|
||||||
|
@ -1220,12 +1170,8 @@ impl OlmMachine {
|
||||||
///
|
///
|
||||||
/// Returns true if the user was queued up for a key query, false otherwise.
|
/// Returns true if the user was queued up for a key query, false otherwise.
|
||||||
pub async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult<bool> {
|
pub async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult<bool> {
|
||||||
if self.store.read().await.tracked_users().contains(user_id) {
|
if self.store.is_user_tracked(user_id) {
|
||||||
self.store
|
self.store.update_tracked_user(user_id, true).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.update_tracked_user(user_id, true)
|
|
||||||
.await?;
|
|
||||||
Ok(true)
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
Ok(false)
|
Ok(false)
|
||||||
|
@ -1251,17 +1197,11 @@ impl OlmMachine {
|
||||||
I: IntoIterator<Item = &'a UserId>,
|
I: IntoIterator<Item = &'a UserId>,
|
||||||
{
|
{
|
||||||
for user in users {
|
for user in users {
|
||||||
if self.store.read().await.tracked_users().contains(user) {
|
if self.store.is_user_tracked(user) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = self
|
if let Err(e) = self.store.update_tracked_user(user, true).await {
|
||||||
.store
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.update_tracked_user(user, true)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
warn!("Error storing users for tracking {}", e);
|
warn!("Error storing users for tracking {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1269,14 +1209,14 @@ impl OlmMachine {
|
||||||
|
|
||||||
/// Should the client perform a key query request.
|
/// Should the client perform a key query request.
|
||||||
pub async fn should_query_keys(&self) -> bool {
|
pub async fn should_query_keys(&self) -> bool {
|
||||||
!self.store.read().await.users_for_key_query().is_empty()
|
self.store.has_users_for_key_query()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the set of users that we need to query keys for.
|
/// Get the set of users that we need to query keys for.
|
||||||
///
|
///
|
||||||
/// Returns a hash set of users that need to be queried for keys.
|
/// Returns a hash set of users that need to be queried for keys.
|
||||||
pub async fn users_for_key_query(&self) -> HashSet<UserId> {
|
pub async fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||||
self.store.read().await.users_for_key_query().clone()
|
self.store.users_for_key_query()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1399,19 +1339,8 @@ mod test {
|
||||||
|
|
||||||
let alice_deivce = Device::from_machine(&alice).await;
|
let alice_deivce = Device::from_machine(&alice).await;
|
||||||
let bob_device = Device::from_machine(&bob).await;
|
let bob_device = Device::from_machine(&bob).await;
|
||||||
alice
|
alice.store.save_devices(&[bob_device]).await.unwrap();
|
||||||
.store
|
bob.store.save_devices(&[alice_deivce]).await.unwrap();
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_devices(&[bob_device])
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
bob.store
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_devices(&[alice_deivce])
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
(alice, bob, otk)
|
(alice, bob, otk)
|
||||||
}
|
}
|
||||||
|
@ -1444,8 +1373,6 @@ mod test {
|
||||||
|
|
||||||
let bob_device = alice
|
let bob_device = alice
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&bob.user_id, &bob.device_id)
|
.get_device(&bob.user_id, &bob.device_id)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1650,13 +1577,7 @@ mod test {
|
||||||
let alice_id = user_id!("@alice:example.org");
|
let alice_id = user_id!("@alice:example.org");
|
||||||
let alice_device_id: &DeviceId = "JLAFKJWSCS".into();
|
let alice_device_id: &DeviceId = "JLAFKJWSCS".into();
|
||||||
|
|
||||||
let alice_devices = machine
|
let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap();
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_user_devices(&alice_id)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(alice_devices.devices().peekable().peek().is_none());
|
assert!(alice_devices.devices().peekable().peek().is_none());
|
||||||
|
|
||||||
machine
|
machine
|
||||||
|
@ -1666,8 +1587,6 @@ mod test {
|
||||||
|
|
||||||
let device = machine
|
let device = machine
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&alice_id, alice_device_id)
|
.get_device(&alice_id, alice_device_id)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1719,8 +1638,6 @@ mod test {
|
||||||
|
|
||||||
let session = alice_machine
|
let session = alice_machine
|
||||||
.store
|
.store
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.get_sessions(bob_machine.account.identity_keys().curve25519())
|
.get_sessions(bob_machine.account.identity_keys().curve25519())
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1735,8 +1652,6 @@ mod test {
|
||||||
|
|
||||||
let bob_device = alice
|
let bob_device = alice
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&bob.user_id, &bob.device_id)
|
.get_device(&bob.user_id, &bob.device_id)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1798,8 +1713,6 @@ mod test {
|
||||||
|
|
||||||
let session = bob
|
let session = bob
|
||||||
.store
|
.store
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.get_inbound_group_session(
|
.get_inbound_group_session(
|
||||||
&room_id,
|
&room_id,
|
||||||
alice.account.identity_keys().curve25519(),
|
alice.account.identity_keys().curve25519(),
|
||||||
|
|
|
@ -43,7 +43,7 @@ impl SessionStore {
|
||||||
///
|
///
|
||||||
/// Returns true if the the session was added, false if the session was
|
/// Returns true if the the session was added, false if the session was
|
||||||
/// already in the store.
|
/// already in the store.
|
||||||
pub async fn add(&mut self, session: Session) -> bool {
|
pub async fn add(&self, session: Session) -> bool {
|
||||||
if !self.entries.contains_key(&*session.sender_key) {
|
if !self.entries.contains_key(&*session.sender_key) {
|
||||||
self.entries.insert(
|
self.entries.insert(
|
||||||
session.sender_key.to_string(),
|
session.sender_key.to_string(),
|
||||||
|
@ -62,11 +62,12 @@ impl SessionStore {
|
||||||
|
|
||||||
/// Get all the sessions that belong to the given sender key.
|
/// Get all the sessions that belong to the given sender key.
|
||||||
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
|
pub fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
|
||||||
|
#[allow(clippy::map_clone)]
|
||||||
self.entries.get(sender_key).map(|s| s.clone())
|
self.entries.get(sender_key).map(|s| s.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a list of sessions belonging to the sender key.
|
/// Add a list of sessions belonging to the sender key.
|
||||||
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Session>) {
|
pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
|
||||||
self.entries
|
self.entries
|
||||||
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
|
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
|
||||||
}
|
}
|
||||||
|
@ -75,6 +76,7 @@ impl SessionStore {
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Debug, Default, Clone)]
|
||||||
/// In-memory store that holds inbound group sessions.
|
/// In-memory store that holds inbound group sessions.
|
||||||
pub struct GroupSessionStore {
|
pub struct GroupSessionStore {
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
entries: Arc<DashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>>,
|
entries: Arc<DashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,7 +92,7 @@ impl GroupSessionStore {
|
||||||
///
|
///
|
||||||
/// Returns true if the the session was added, false if the session was
|
/// Returns true if the the session was added, false if the session was
|
||||||
/// already in the store.
|
/// already in the store.
|
||||||
pub fn add(&mut self, session: InboundGroupSession) -> bool {
|
pub fn add(&self, session: InboundGroupSession) -> bool {
|
||||||
if !self.entries.contains_key(&session.room_id) {
|
if !self.entries.contains_key(&session.room_id) {
|
||||||
let room_id = &*session.room_id;
|
let room_id = &*session.room_id;
|
||||||
self.entries.insert(room_id.clone(), HashMap::new());
|
self.entries.insert(room_id.clone(), HashMap::new());
|
||||||
|
@ -223,7 +225,7 @@ mod test {
|
||||||
async fn test_session_store() {
|
async fn test_session_store() {
|
||||||
let (_, session) = get_account_and_session().await;
|
let (_, session) = get_account_and_session().await;
|
||||||
|
|
||||||
let mut store = SessionStore::new();
|
let store = SessionStore::new();
|
||||||
|
|
||||||
assert!(store.add(session.clone()).await);
|
assert!(store.add(session.clone()).await);
|
||||||
assert!(!store.add(session.clone()).await);
|
assert!(!store.add(session.clone()).await);
|
||||||
|
@ -240,7 +242,7 @@ mod test {
|
||||||
async fn test_session_store_bulk_storing() {
|
async fn test_session_store_bulk_storing() {
|
||||||
let (_, session) = get_account_and_session().await;
|
let (_, session) = get_account_and_session().await;
|
||||||
|
|
||||||
let mut store = SessionStore::new();
|
let store = SessionStore::new();
|
||||||
store.set_for_sender(&session.sender_key, vec![session.clone()]);
|
store.set_for_sender(&session.sender_key, vec![session.clone()]);
|
||||||
|
|
||||||
let sessions = store.get(&session.sender_key).unwrap();
|
let sessions = store.get(&session.sender_key).unwrap();
|
||||||
|
@ -271,7 +273,7 @@ mod test {
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut store = GroupSessionStore::new();
|
let store = GroupSessionStore::new();
|
||||||
store.add(inbound.clone());
|
store.add(inbound.clone());
|
||||||
|
|
||||||
let loaded_session = store
|
let loaded_session = store
|
||||||
|
|
|
@ -63,7 +63,7 @@ pub struct Account {
|
||||||
uploaded_signed_key_count: Arc<AtomicI64>,
|
uploaded_signed_key_count: Arc<AtomicI64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl fmt::Debug for Account {
|
impl fmt::Debug for Account {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("Account")
|
f.debug_struct("Account")
|
||||||
|
|
|
@ -222,7 +222,7 @@ impl InboundGroupSession {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl fmt::Debug for InboundGroupSession {
|
impl fmt::Debug for InboundGroupSession {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("InboundGroupSession")
|
f.debug_struct("InboundGroupSession")
|
||||||
|
@ -401,7 +401,7 @@ impl OutboundGroupSession {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl std::fmt::Debug for OutboundGroupSession {
|
impl std::fmt::Debug for OutboundGroupSession {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("OutboundGroupSession")
|
f.debug_struct("OutboundGroupSession")
|
||||||
|
|
|
@ -51,7 +51,7 @@ pub struct Session {
|
||||||
pub(crate) last_use_time: Arc<Instant>,
|
pub(crate) last_use_time: Arc<Instant>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl fmt::Debug for Session {
|
impl fmt::Debug for Session {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("Session")
|
f.debug_struct("Session")
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
use std::{collections::HashSet, sync::Arc};
|
use std::{collections::HashSet, sync::Arc};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use dashmap::DashSet;
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
identifiers::{DeviceId, RoomId, UserId},
|
identifiers::{DeviceId, RoomId, UserId},
|
||||||
locks::Mutex,
|
locks::Mutex,
|
||||||
|
@ -25,12 +26,12 @@ use crate::{
|
||||||
device::Device,
|
device::Device,
|
||||||
memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices},
|
memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices},
|
||||||
};
|
};
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MemoryStore {
|
pub struct MemoryStore {
|
||||||
sessions: SessionStore,
|
sessions: SessionStore,
|
||||||
inbound_group_sessions: GroupSessionStore,
|
inbound_group_sessions: GroupSessionStore,
|
||||||
tracked_users: HashSet<UserId>,
|
tracked_users: Arc<DashSet<UserId>>,
|
||||||
users_for_key_query: HashSet<UserId>,
|
users_for_key_query: Arc<DashSet<UserId>>,
|
||||||
devices: DeviceStore,
|
devices: DeviceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,8 +40,8 @@ impl MemoryStore {
|
||||||
MemoryStore {
|
MemoryStore {
|
||||||
sessions: SessionStore::new(),
|
sessions: SessionStore::new(),
|
||||||
inbound_group_sessions: GroupSessionStore::new(),
|
inbound_group_sessions: GroupSessionStore::new(),
|
||||||
tracked_users: HashSet::new(),
|
tracked_users: Arc::new(DashSet::new()),
|
||||||
users_for_key_query: HashSet::new(),
|
users_for_key_query: Arc::new(DashSet::new()),
|
||||||
devices: DeviceStore::new(),
|
devices: DeviceStore::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -48,15 +49,15 @@ impl MemoryStore {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl CryptoStore for MemoryStore {
|
impl CryptoStore for MemoryStore {
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>> {
|
async fn load_account(&self) -> Result<Option<Account>> {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_account(&mut self, _: Account) -> Result<()> {
|
async fn save_account(&self, _: Account) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
|
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
|
||||||
for session in sessions {
|
for session in sessions {
|
||||||
let _ = self.sessions.add(session.clone()).await;
|
let _ = self.sessions.add(session.clone()).await;
|
||||||
}
|
}
|
||||||
|
@ -64,16 +65,16 @@ impl CryptoStore for MemoryStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||||
Ok(self.sessions.get(sender_key))
|
Ok(self.sessions.get(sender_key))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
|
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
|
||||||
Ok(self.inbound_group_sessions.add(session))
|
Ok(self.inbound_group_sessions.add(session))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
&mut self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
|
@ -83,15 +84,20 @@ impl CryptoStore for MemoryStore {
|
||||||
.get(room_id, sender_key, session_id))
|
.get(room_id, sender_key, session_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tracked_users(&self) -> &HashSet<UserId> {
|
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||||
&self.tracked_users
|
#[allow(clippy::map_clone)]
|
||||||
|
self.users_for_key_query.iter().map(|u| u.clone()).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn users_for_key_query(&self) -> &HashSet<UserId> {
|
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
||||||
&self.users_for_key_query
|
self.tracked_users.contains(user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
|
fn has_users_for_key_query(&self) -> bool {
|
||||||
|
!self.users_for_key_query.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||||
if dirty {
|
if dirty {
|
||||||
self.users_for_key_query.insert(user.clone());
|
self.users_for_key_query.insert(user.clone());
|
||||||
} else {
|
} else {
|
||||||
|
@ -135,7 +141,7 @@ mod test {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_session_store() {
|
async fn test_session_store() {
|
||||||
let (account, session) = get_account_and_session().await;
|
let (account, session) = get_account_and_session().await;
|
||||||
let mut store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
|
|
||||||
assert!(store.load_account().await.unwrap().is_none());
|
assert!(store.load_account().await.unwrap().is_none());
|
||||||
store.save_account(account).await.unwrap();
|
store.save_account(account).await.unwrap();
|
||||||
|
@ -168,7 +174,7 @@ mod test {
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
let _ = store
|
let _ = store
|
||||||
.save_inbound_group_session(inbound.clone())
|
.save_inbound_group_session(inbound.clone())
|
||||||
.await
|
.await
|
||||||
|
@ -217,7 +223,7 @@ mod test {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_tracked_users() {
|
async fn test_tracked_users() {
|
||||||
let device = get_device();
|
let device = get_device();
|
||||||
let mut store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
|
|
||||||
assert!(store
|
assert!(store
|
||||||
.update_tracked_user(device.user_id(), false)
|
.update_tracked_user(device.user_id(), false)
|
||||||
|
@ -228,8 +234,6 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap());
|
.unwrap());
|
||||||
|
|
||||||
let tracked_users = store.tracked_users();
|
assert!(store.is_user_tracked(device.user_id()));
|
||||||
|
|
||||||
let _ = tracked_users.contains(device.user_id());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,28 +95,28 @@ pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||||
/// keys.
|
/// keys.
|
||||||
pub trait CryptoStore: Debug {
|
pub trait CryptoStore: Debug {
|
||||||
/// Load an account that was previously stored.
|
/// Load an account that was previously stored.
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
async fn load_account(&self) -> Result<Option<Account>>;
|
||||||
|
|
||||||
/// Save the given account in the store.
|
/// Save the given account in the store.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `account` - The account that should be stored.
|
/// * `account` - The account that should be stored.
|
||||||
async fn save_account(&mut self, account: Account) -> Result<()>;
|
async fn save_account(&self, account: Account) -> Result<()>;
|
||||||
|
|
||||||
/// Save the given sessions in the store.
|
/// Save the given sessions in the store.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `session` - The sessions that should be stored.
|
/// * `session` - The sessions that should be stored.
|
||||||
async fn save_sessions(&mut self, session: &[Session]) -> Result<()>;
|
async fn save_sessions(&self, session: &[Session]) -> Result<()>;
|
||||||
|
|
||||||
/// Get all the sessions that belong to the given sender key.
|
/// Get all the sessions that belong to the given sender key.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `sender_key` - The sender key that was used to establish the sessions.
|
/// * `sender_key` - The sender key that was used to establish the sessions.
|
||||||
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
|
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
|
||||||
|
|
||||||
/// Save the given inbound group session in the store.
|
/// Save the given inbound group session in the store.
|
||||||
///
|
///
|
||||||
|
@ -126,7 +126,7 @@ pub trait CryptoStore: Debug {
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `session` - The session that should be stored.
|
/// * `session` - The session that should be stored.
|
||||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
|
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool>;
|
||||||
|
|
||||||
/// Get the inbound group session from our store.
|
/// Get the inbound group session from our store.
|
||||||
///
|
///
|
||||||
|
@ -137,18 +137,21 @@ pub trait CryptoStore: Debug {
|
||||||
///
|
///
|
||||||
/// * `session_id` - The unique id of the session.
|
/// * `session_id` - The unique id of the session.
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
&mut self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Result<Option<InboundGroupSession>>;
|
) -> Result<Option<InboundGroupSession>>;
|
||||||
|
|
||||||
/// Get the set of tracked users.
|
/// Is the given user already tracked.
|
||||||
fn tracked_users(&self) -> &HashSet<UserId>;
|
fn is_user_tracked(&self, user_id: &UserId) -> bool;
|
||||||
|
|
||||||
|
/// Are there any tracked users that are marked as dirty.
|
||||||
|
fn has_users_for_key_query(&self) -> bool;
|
||||||
|
|
||||||
/// Set of users that we need to query keys for. This is a subset of
|
/// Set of users that we need to query keys for. This is a subset of
|
||||||
/// the tracked users.
|
/// the tracked users.
|
||||||
fn users_for_key_query(&self) -> &HashSet<UserId>;
|
fn users_for_key_query(&self) -> HashSet<UserId>;
|
||||||
|
|
||||||
/// Add an user for tracking.
|
/// Add an user for tracking.
|
||||||
///
|
///
|
||||||
|
@ -159,7 +162,7 @@ pub trait CryptoStore: Debug {
|
||||||
/// * `user` - The user that should be marked as tracked.
|
/// * `user` - The user that should be marked as tracked.
|
||||||
///
|
///
|
||||||
/// * `dirty` - Should the user be also marked for a key query.
|
/// * `dirty` - Should the user be also marked for a key query.
|
||||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool>;
|
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>;
|
||||||
|
|
||||||
/// Save the given devices in the store.
|
/// Save the given devices in the store.
|
||||||
///
|
///
|
||||||
|
|
|
@ -17,10 +17,11 @@ use std::{
|
||||||
convert::TryFrom,
|
convert::TryFrom,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
result::Result as StdResult,
|
result::Result as StdResult,
|
||||||
sync::Arc,
|
sync::{Arc, Mutex as SyncMutex},
|
||||||
};
|
};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use dashmap::DashSet;
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
events::Algorithm,
|
events::Algorithm,
|
||||||
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
|
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
|
||||||
|
@ -40,22 +41,24 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
/// SQLite based implementation of a `CryptoStore`.
|
/// SQLite based implementation of a `CryptoStore`.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct SqliteStore {
|
pub struct SqliteStore {
|
||||||
user_id: Arc<UserId>,
|
user_id: Arc<UserId>,
|
||||||
device_id: Arc<Box<DeviceId>>,
|
device_id: Arc<Box<DeviceId>>,
|
||||||
account_info: Option<AccountInfo>,
|
account_info: Arc<SyncMutex<Option<AccountInfo>>>,
|
||||||
path: PathBuf,
|
path: Arc<PathBuf>,
|
||||||
|
|
||||||
sessions: SessionStore,
|
sessions: SessionStore,
|
||||||
inbound_group_sessions: GroupSessionStore,
|
inbound_group_sessions: GroupSessionStore,
|
||||||
devices: DeviceStore,
|
devices: DeviceStore,
|
||||||
tracked_users: HashSet<UserId>,
|
tracked_users: Arc<DashSet<UserId>>,
|
||||||
users_for_key_query: HashSet<UserId>,
|
users_for_key_query: Arc<DashSet<UserId>>,
|
||||||
|
|
||||||
connection: Arc<Mutex<SqliteConnection>>,
|
connection: Arc<Mutex<SqliteConnection>>,
|
||||||
pickle_passphrase: Option<Zeroizing<String>>,
|
pickle_passphrase: Arc<Option<Zeroizing<String>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
struct AccountInfo {
|
struct AccountInfo {
|
||||||
account_id: i64,
|
account_id: i64,
|
||||||
identity_keys: Arc<IdentityKeys>,
|
identity_keys: Arc<IdentityKeys>,
|
||||||
|
@ -130,22 +133,26 @@ impl SqliteStore {
|
||||||
let store = SqliteStore {
|
let store = SqliteStore {
|
||||||
user_id: Arc::new(user_id.to_owned()),
|
user_id: Arc::new(user_id.to_owned()),
|
||||||
device_id: Arc::new(device_id.into()),
|
device_id: Arc::new(device_id.into()),
|
||||||
account_info: None,
|
account_info: Arc::new(SyncMutex::new(None)),
|
||||||
sessions: SessionStore::new(),
|
sessions: SessionStore::new(),
|
||||||
inbound_group_sessions: GroupSessionStore::new(),
|
inbound_group_sessions: GroupSessionStore::new(),
|
||||||
devices: DeviceStore::new(),
|
devices: DeviceStore::new(),
|
||||||
path: path.as_ref().to_owned(),
|
path: Arc::new(path.as_ref().to_owned()),
|
||||||
connection: Arc::new(Mutex::new(connection)),
|
connection: Arc::new(Mutex::new(connection)),
|
||||||
pickle_passphrase: passphrase,
|
pickle_passphrase: Arc::new(passphrase),
|
||||||
tracked_users: HashSet::new(),
|
tracked_users: Arc::new(DashSet::new()),
|
||||||
users_for_key_query: HashSet::new(),
|
users_for_key_query: Arc::new(DashSet::new()),
|
||||||
};
|
};
|
||||||
store.create_tables().await?;
|
store.create_tables().await?;
|
||||||
Ok(store)
|
Ok(store)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn account_id(&self) -> Option<i64> {
|
fn account_id(&self) -> Option<i64> {
|
||||||
self.account_info.as_ref().map(|i| i.account_id)
|
self.account_info
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.as_ref()
|
||||||
|
.map(|i| i.account_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_tables(&self) -> Result<()> {
|
async fn create_tables(&self) -> Result<()> {
|
||||||
|
@ -299,7 +306,7 @@ impl SqliteStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn lazy_load_sessions(&mut self, sender_key: &str) -> Result<()> {
|
async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> {
|
||||||
let loaded_sessions = self.sessions.get(sender_key).is_some();
|
let loaded_sessions = self.sessions.get(sender_key).is_some();
|
||||||
|
|
||||||
if !loaded_sessions {
|
if !loaded_sessions {
|
||||||
|
@ -313,18 +320,17 @@ impl SqliteStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_sessions_for(
|
async fn get_sessions_for(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||||
&mut self,
|
|
||||||
sender_key: &str,
|
|
||||||
) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
|
||||||
self.lazy_load_sessions(sender_key).await?;
|
self.lazy_load_sessions(sender_key).await?;
|
||||||
Ok(self.sessions.get(sender_key))
|
Ok(self.sessions.get(sender_key))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load_sessions_for(&mut self, sender_key: &str) -> Result<Vec<Session>> {
|
async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> {
|
||||||
let account_info = self
|
let account_info = self
|
||||||
.account_info
|
.account_info
|
||||||
.as_ref()
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.clone()
|
||||||
.ok_or(CryptoStoreError::AccountUnset)?;
|
.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
@ -365,7 +371,7 @@ impl SqliteStore {
|
||||||
.collect::<Result<Vec<Session>>>()?)
|
.collect::<Result<Vec<Session>>>()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
|
async fn load_inbound_group_sessions(&self) -> Result<()> {
|
||||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
@ -377,7 +383,7 @@ impl SqliteStore {
|
||||||
.fetch_all(&mut *connection)
|
.fetch_all(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(rows
|
let mut group_sessions = rows
|
||||||
.iter()
|
.iter()
|
||||||
.map(|row| {
|
.map(|row| {
|
||||||
let pickle = &row.0;
|
let pickle = &row.0;
|
||||||
|
@ -393,7 +399,16 @@ impl SqliteStore {
|
||||||
RoomId::try_from(room_id.as_str()).unwrap(),
|
RoomId::try_from(room_id.as_str()).unwrap(),
|
||||||
)?)
|
)?)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
.collect::<Result<Vec<InboundGroupSession>>>()?;
|
||||||
|
|
||||||
|
group_sessions
|
||||||
|
.drain(..)
|
||||||
|
.map(|s| {
|
||||||
|
self.inbound_group_sessions.add(s);
|
||||||
|
})
|
||||||
|
.for_each(drop);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> {
|
async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> {
|
||||||
|
@ -417,7 +432,7 @@ impl SqliteStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> {
|
async fn load_tracked_users(&self) -> Result<()> {
|
||||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
@ -429,27 +444,24 @@ impl SqliteStore {
|
||||||
.fetch_all(&mut *connection)
|
.fetch_all(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut users = HashSet::new();
|
|
||||||
let mut users_for_query = HashSet::new();
|
|
||||||
|
|
||||||
for row in rows {
|
for row in rows {
|
||||||
let user_id: &str = &row.0;
|
let user_id: &str = &row.0;
|
||||||
let dirty: bool = row.1;
|
let dirty: bool = row.1;
|
||||||
|
|
||||||
if let Ok(u) = UserId::try_from(user_id) {
|
if let Ok(u) = UserId::try_from(user_id) {
|
||||||
users.insert(u.clone());
|
self.tracked_users.insert(u.clone());
|
||||||
if dirty {
|
if dirty {
|
||||||
users_for_query.insert(u);
|
self.users_for_key_query.insert(u);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((users, users_for_query))
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load_devices(&self) -> Result<DeviceStore> {
|
async fn load_devices(&self) -> Result<()> {
|
||||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
@ -461,8 +473,6 @@ impl SqliteStore {
|
||||||
.fetch_all(&mut *connection)
|
.fetch_all(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let store = DeviceStore::new();
|
|
||||||
|
|
||||||
for row in rows {
|
for row in rows {
|
||||||
let device_row_id = row.0;
|
let device_row_id = row.0;
|
||||||
let user_id: &str = &row.1;
|
let user_id: &str = &row.1;
|
||||||
|
@ -555,10 +565,10 @@ impl SqliteStore {
|
||||||
signatures,
|
signatures,
|
||||||
);
|
);
|
||||||
|
|
||||||
store.add(device);
|
self.devices.add(device);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(store)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_device_helper(&self, device: Device) -> Result<()> {
|
async fn save_device_helper(&self, device: Device) -> Result<()> {
|
||||||
|
@ -643,7 +653,7 @@ impl SqliteStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pickle_mode(&self) -> PicklingMode {
|
fn get_pickle_mode(&self) -> PicklingMode {
|
||||||
match &self.pickle_passphrase {
|
match &*self.pickle_passphrase {
|
||||||
Some(p) => PicklingMode::Encrypted {
|
Some(p) => PicklingMode::Encrypted {
|
||||||
key: p.as_bytes().to_vec(),
|
key: p.as_bytes().to_vec(),
|
||||||
},
|
},
|
||||||
|
@ -654,7 +664,7 @@ impl SqliteStore {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl CryptoStore for SqliteStore {
|
impl CryptoStore for SqliteStore {
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>> {
|
async fn load_account(&self) -> Result<Option<Account>> {
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
let row: Option<(i64, String, bool, i64)> = query_as(
|
let row: Option<(i64, String, bool, i64)> = query_as(
|
||||||
|
@ -676,7 +686,7 @@ impl CryptoStore for SqliteStore {
|
||||||
&self.device_id,
|
&self.device_id,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
self.account_info = Some(AccountInfo {
|
*self.account_info.lock().unwrap() = Some(AccountInfo {
|
||||||
account_id: id,
|
account_id: id,
|
||||||
identity_keys: account.identity_keys.clone(),
|
identity_keys: account.identity_keys.clone(),
|
||||||
});
|
});
|
||||||
|
@ -688,26 +698,14 @@ impl CryptoStore for SqliteStore {
|
||||||
|
|
||||||
drop(connection);
|
drop(connection);
|
||||||
|
|
||||||
let mut group_sessions = self.load_inbound_group_sessions().await?;
|
self.load_inbound_group_sessions().await?;
|
||||||
|
self.load_devices().await?;
|
||||||
group_sessions
|
self.load_tracked_users().await?;
|
||||||
.drain(..)
|
|
||||||
.map(|s| {
|
|
||||||
self.inbound_group_sessions.add(s);
|
|
||||||
})
|
|
||||||
.for_each(drop);
|
|
||||||
|
|
||||||
let devices = self.load_devices().await?;
|
|
||||||
self.devices = devices;
|
|
||||||
|
|
||||||
let (tracked_users, users_for_query) = self.load_tracked_users().await?;
|
|
||||||
self.tracked_users = tracked_users;
|
|
||||||
self.users_for_key_query = users_for_query;
|
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_account(&mut self, account: Account) -> Result<()> {
|
async fn save_account(&self, account: Account) -> Result<()> {
|
||||||
let pickle = account.pickle(self.get_pickle_mode()).await;
|
let pickle = account.pickle(self.get_pickle_mode()).await;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
@ -735,7 +733,7 @@ impl CryptoStore for SqliteStore {
|
||||||
.fetch_one(&mut *connection)
|
.fetch_one(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
self.account_info = Some(AccountInfo {
|
*self.account_info.lock().unwrap() = Some(AccountInfo {
|
||||||
account_id: account_id.0,
|
account_id: account_id.0,
|
||||||
identity_keys: account.identity_keys.clone(),
|
identity_keys: account.identity_keys.clone(),
|
||||||
});
|
});
|
||||||
|
@ -743,7 +741,7 @@ impl CryptoStore for SqliteStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
|
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
|
||||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
|
||||||
// TODO turn this into a transaction
|
// TODO turn this into a transaction
|
||||||
|
@ -776,11 +774,11 @@ impl CryptoStore for SqliteStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||||
Ok(self.get_sessions_for(sender_key).await?)
|
Ok(self.get_sessions_for(sender_key).await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
|
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
|
||||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let pickle = session.pickle(self.get_pickle_mode()).await;
|
let pickle = session.pickle(self.get_pickle_mode()).await;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
@ -808,7 +806,7 @@ impl CryptoStore for SqliteStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
&mut self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
|
@ -818,15 +816,20 @@ impl CryptoStore for SqliteStore {
|
||||||
.get(room_id, sender_key, session_id))
|
.get(room_id, sender_key, session_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tracked_users(&self) -> &HashSet<UserId> {
|
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
||||||
&self.tracked_users
|
self.tracked_users.contains(user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn users_for_key_query(&self) -> &HashSet<UserId> {
|
fn has_users_for_key_query(&self) -> bool {
|
||||||
&self.users_for_key_query
|
!self.users_for_key_query.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
|
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||||
|
#[allow(clippy::map_clone)]
|
||||||
|
self.users_for_key_query.iter().map(|u| u.clone()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||||
let already_added = self.tracked_users.insert(user.clone());
|
let already_added = self.tracked_users.insert(user.clone());
|
||||||
|
|
||||||
if dirty {
|
if dirty {
|
||||||
|
@ -878,7 +881,7 @@ impl CryptoStore for SqliteStore {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl std::fmt::Debug for SqliteStore {
|
impl std::fmt::Debug for SqliteStore {
|
||||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> {
|
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> {
|
||||||
fmt.debug_struct("SqliteStore")
|
fmt.debug_struct("SqliteStore")
|
||||||
|
@ -931,7 +934,7 @@ mod test {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_loaded_store() -> (Account, SqliteStore, tempfile::TempDir) {
|
async fn get_loaded_store() -> (Account, SqliteStore, tempfile::TempDir) {
|
||||||
let (mut store, dir) = get_store(None).await;
|
let (store, dir) = get_store(None).await;
|
||||||
let account = get_account();
|
let account = get_account();
|
||||||
store
|
store
|
||||||
.save_account(account.clone())
|
.save_account(account.clone())
|
||||||
|
@ -999,7 +1002,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn save_account() {
|
async fn save_account() {
|
||||||
let (mut store, _dir) = get_store(None).await;
|
let (store, _dir) = get_store(None).await;
|
||||||
assert!(store.load_account().await.unwrap().is_none());
|
assert!(store.load_account().await.unwrap().is_none());
|
||||||
let account = get_account();
|
let account = get_account();
|
||||||
|
|
||||||
|
@ -1011,7 +1014,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn load_account() {
|
async fn load_account() {
|
||||||
let (mut store, _dir) = get_store(None).await;
|
let (store, _dir) = get_store(None).await;
|
||||||
let account = get_account();
|
let account = get_account();
|
||||||
|
|
||||||
store
|
store
|
||||||
|
@ -1027,7 +1030,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn load_account_with_passphrase() {
|
async fn load_account_with_passphrase() {
|
||||||
let (mut store, _dir) = get_store(Some("secret_passphrase")).await;
|
let (store, _dir) = get_store(Some("secret_passphrase")).await;
|
||||||
let account = get_account();
|
let account = get_account();
|
||||||
|
|
||||||
store
|
store
|
||||||
|
@ -1043,7 +1046,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn save_and_share_account() {
|
async fn save_and_share_account() {
|
||||||
let (mut store, _dir) = get_store(None).await;
|
let (store, _dir) = get_store(None).await;
|
||||||
let account = get_account();
|
let account = get_account();
|
||||||
|
|
||||||
store
|
store
|
||||||
|
@ -1066,7 +1069,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn save_session() {
|
async fn save_session() {
|
||||||
let (mut store, _dir) = get_store(None).await;
|
let (store, _dir) = get_store(None).await;
|
||||||
let (account, session) = get_account_and_session().await;
|
let (account, session) = get_account_and_session().await;
|
||||||
|
|
||||||
assert!(store.save_sessions(&[session.clone()]).await.is_err());
|
assert!(store.save_sessions(&[session.clone()]).await.is_err());
|
||||||
|
@ -1081,7 +1084,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn load_sessions() {
|
async fn load_sessions() {
|
||||||
let (mut store, _dir) = get_store(None).await;
|
let (store, _dir) = get_store(None).await;
|
||||||
let (account, session) = get_account_and_session().await;
|
let (account, session) = get_account_and_session().await;
|
||||||
store
|
store
|
||||||
.save_account(account.clone())
|
.save_account(account.clone())
|
||||||
|
@ -1100,7 +1103,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn add_and_save_session() {
|
async fn add_and_save_session() {
|
||||||
let (mut store, dir) = get_store(None).await;
|
let (store, dir) = get_store(None).await;
|
||||||
let (account, session) = get_account_and_session().await;
|
let (account, session) = get_account_and_session().await;
|
||||||
let sender_key = session.sender_key.to_owned();
|
let sender_key = session.sender_key.to_owned();
|
||||||
let session_id = session.session_id().to_owned();
|
let session_id = session.session_id().to_owned();
|
||||||
|
@ -1119,7 +1122,7 @@ mod test {
|
||||||
|
|
||||||
drop(store);
|
drop(store);
|
||||||
|
|
||||||
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
||||||
.await
|
.await
|
||||||
.expect("Can't create store");
|
.expect("Can't create store");
|
||||||
|
|
||||||
|
@ -1135,7 +1138,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn save_inbound_group_session() {
|
async fn save_inbound_group_session() {
|
||||||
let (account, mut store, _dir) = get_loaded_store().await;
|
let (account, store, _dir) = get_loaded_store().await;
|
||||||
|
|
||||||
let identity_keys = account.identity_keys();
|
let identity_keys = account.identity_keys();
|
||||||
let outbound_session = OlmOutboundGroupSession::new();
|
let outbound_session = OlmOutboundGroupSession::new();
|
||||||
|
@ -1155,7 +1158,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn load_inbound_group_session() {
|
async fn load_inbound_group_session() {
|
||||||
let (account, mut store, _dir) = get_loaded_store().await;
|
let (account, store, _dir) = get_loaded_store().await;
|
||||||
|
|
||||||
let identity_keys = account.identity_keys();
|
let identity_keys = account.identity_keys();
|
||||||
let outbound_session = OlmOutboundGroupSession::new();
|
let outbound_session = OlmOutboundGroupSession::new();
|
||||||
|
@ -1167,16 +1170,12 @@ mod test {
|
||||||
)
|
)
|
||||||
.expect("Can't create session");
|
.expect("Can't create session");
|
||||||
|
|
||||||
let session_id = session.session_id().to_owned();
|
|
||||||
|
|
||||||
store
|
store
|
||||||
.save_inbound_group_session(session.clone())
|
.save_inbound_group_session(session.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Can't save group session");
|
.expect("Can't save group session");
|
||||||
|
|
||||||
let sessions = store.load_inbound_group_sessions().await.unwrap();
|
store.load_inbound_group_sessions().await.unwrap();
|
||||||
|
|
||||||
assert_eq!(session_id, sessions[0].session_id());
|
|
||||||
|
|
||||||
let loaded_session = store
|
let loaded_session = store
|
||||||
.get_inbound_group_session(&session.room_id, &session.sender_key, session.session_id())
|
.get_inbound_group_session(&session.room_id, &session.sender_key, session.session_id())
|
||||||
|
@ -1188,7 +1187,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_tracked_users() {
|
async fn test_tracked_users() {
|
||||||
let (_account, mut store, dir) = get_loaded_store().await;
|
let (_account, store, dir) = get_loaded_store().await;
|
||||||
let device = get_device();
|
let device = get_device();
|
||||||
|
|
||||||
assert!(store
|
assert!(store
|
||||||
|
@ -1200,9 +1199,7 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap());
|
.unwrap());
|
||||||
|
|
||||||
let tracked_users = store.tracked_users();
|
assert!(store.is_user_tracked(device.user_id()));
|
||||||
|
|
||||||
assert!(tracked_users.contains(device.user_id()));
|
|
||||||
assert!(!store.users_for_key_query().contains(device.user_id()));
|
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||||
assert!(!store
|
assert!(!store
|
||||||
.update_tracked_user(device.user_id(), true)
|
.update_tracked_user(device.user_id(), true)
|
||||||
|
@ -1211,14 +1208,13 @@ mod test {
|
||||||
assert!(store.users_for_key_query().contains(device.user_id()));
|
assert!(store.users_for_key_query().contains(device.user_id()));
|
||||||
drop(store);
|
drop(store);
|
||||||
|
|
||||||
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
||||||
.await
|
.await
|
||||||
.expect("Can't create store");
|
.expect("Can't create store");
|
||||||
|
|
||||||
store.load_account().await.unwrap();
|
store.load_account().await.unwrap();
|
||||||
|
|
||||||
let tracked_users = store.tracked_users();
|
assert!(store.is_user_tracked(device.user_id()));
|
||||||
assert!(tracked_users.contains(device.user_id()));
|
|
||||||
assert!(store.users_for_key_query().contains(device.user_id()));
|
assert!(store.users_for_key_query().contains(device.user_id()));
|
||||||
|
|
||||||
store
|
store
|
||||||
|
@ -1227,7 +1223,7 @@ mod test {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(!store.users_for_key_query().contains(device.user_id()));
|
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||||
|
|
||||||
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
||||||
.await
|
.await
|
||||||
.expect("Can't create store");
|
.expect("Can't create store");
|
||||||
|
|
||||||
|
@ -1245,7 +1241,7 @@ mod test {
|
||||||
|
|
||||||
drop(store);
|
drop(store);
|
||||||
|
|
||||||
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
||||||
.await
|
.await
|
||||||
.expect("Can't create store");
|
.expect("Can't create store");
|
||||||
|
|
||||||
|
@ -1278,7 +1274,7 @@ mod test {
|
||||||
store.save_devices(&[device.clone()]).await.unwrap();
|
store.save_devices(&[device.clone()]).await.unwrap();
|
||||||
store.delete_device(device.clone()).await.unwrap();
|
store.delete_device(device.clone()).await.unwrap();
|
||||||
|
|
||||||
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
|
||||||
.await
|
.await
|
||||||
.expect("Can't create store");
|
.expect("Can't create store");
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ use matrix_sdk_common::{
|
||||||
api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest,
|
api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest,
|
||||||
events::{AnyToDeviceEvent, AnyToDeviceEventContent},
|
events::{AnyToDeviceEvent, AnyToDeviceEventContent},
|
||||||
identifiers::{DeviceId, UserId},
|
identifiers::{DeviceId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::sas::{content_to_request, Sas};
|
use super::sas::{content_to_request, Sas};
|
||||||
|
@ -31,13 +30,13 @@ use crate::{Account, CryptoStore, CryptoStoreError};
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct VerificationMachine {
|
pub struct VerificationMachine {
|
||||||
account: Account,
|
account: Account,
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
verifications: Arc<DashMap<String, Sas>>,
|
verifications: Arc<DashMap<String, Sas>>,
|
||||||
outgoing_to_device_messages: Arc<DashMap<String, OwnedToDeviceRequest>>,
|
outgoing_to_device_messages: Arc<DashMap<String, OwnedToDeviceRequest>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VerificationMachine {
|
impl VerificationMachine {
|
||||||
pub(crate) fn new(account: Account, store: Arc<RwLock<Box<dyn CryptoStore>>>) -> Self {
|
pub(crate) fn new(account: Account, store: Arc<Box<dyn CryptoStore>>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
account,
|
account,
|
||||||
store,
|
store,
|
||||||
|
@ -112,8 +111,6 @@ impl VerificationMachine {
|
||||||
|
|
||||||
if let Some(d) = self
|
if let Some(d) = self
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&e.sender, &e.content.from_device)
|
.get_device(&e.sender, &e.content.from_device)
|
||||||
.await?
|
.await?
|
||||||
{
|
{
|
||||||
|
@ -179,7 +176,6 @@ mod test {
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
events::AnyToDeviceEventContent,
|
events::AnyToDeviceEventContent,
|
||||||
identifiers::{DeviceId, UserId},
|
identifiers::{DeviceId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{Sas, VerificationMachine};
|
use super::{Sas, VerificationMachine};
|
||||||
|
@ -209,21 +205,18 @@ mod test {
|
||||||
let alice = Account::new(&alice_id(), &alice_device_id());
|
let alice = Account::new(&alice_id(), &alice_device_id());
|
||||||
let bob = Account::new(&bob_id(), &bob_device_id());
|
let bob = Account::new(&bob_id(), &bob_device_id());
|
||||||
let store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
let bob_store: Arc<RwLock<Box<dyn CryptoStore>>> =
|
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||||
Arc::new(RwLock::new(Box::new(MemoryStore::new())));
|
|
||||||
|
|
||||||
let bob_device = Device::from_account(&bob).await;
|
let bob_device = Device::from_account(&bob).await;
|
||||||
let alice_device = Device::from_account(&alice).await;
|
let alice_device = Device::from_account(&alice).await;
|
||||||
|
|
||||||
store.save_devices(&[bob_device]).await.unwrap();
|
store.save_devices(&[bob_device]).await.unwrap();
|
||||||
bob_store
|
bob_store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.save_devices(&[alice_device.clone()])
|
.save_devices(&[alice_device.clone()])
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let machine = VerificationMachine::new(alice, Arc::new(RwLock::new(Box::new(store))));
|
let machine = VerificationMachine::new(alice, Arc::new(Box::new(store)));
|
||||||
let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store);
|
let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store);
|
||||||
machine
|
machine
|
||||||
.receive_event(&mut wrap_any_to_device_content(
|
.receive_event(&mut wrap_any_to_device_content(
|
||||||
|
@ -240,7 +233,7 @@ mod test {
|
||||||
fn create() {
|
fn create() {
|
||||||
let alice = Account::new(&alice_id(), &alice_device_id());
|
let alice = Account::new(&alice_id(), &alice_device_id());
|
||||||
let store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
let _ = VerificationMachine::new(alice, Arc::new(RwLock::new(Box::new(store))));
|
let _ = VerificationMachine::new(alice, Arc::new(Box::new(store)));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
@ -12,7 +12,6 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
mod machine;
|
mod machine;
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
mod sas;
|
mod sas;
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{collections::BTreeMap, convert::TryInto};
|
use std::{collections::BTreeMap, convert::TryInto};
|
||||||
|
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
|
@ -212,8 +226,6 @@ fn extra_mac_info_send(ids: &SasIds, flow_id: &str) -> String {
|
||||||
///
|
///
|
||||||
/// * `flow_id` - The unique id that identifies this SAS verification process.
|
/// * `flow_id` - The unique id that identifies this SAS verification process.
|
||||||
///
|
///
|
||||||
/// * `we_started` - Flag signaling if the SAS process was started on our side.
|
|
||||||
///
|
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// This will panic if the public key of the other side wasn't set.
|
/// This will panic if the public key of the other side wasn't set.
|
||||||
|
|
|
@ -31,7 +31,6 @@ use matrix_sdk_common::{
|
||||||
AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent,
|
AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent,
|
||||||
},
|
},
|
||||||
identifiers::{DeviceId, UserId},
|
identifiers::{DeviceId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{Account, CryptoStore, CryptoStoreError, Device, TrustState};
|
use crate::{Account, CryptoStore, CryptoStoreError, Device, TrustState};
|
||||||
|
@ -45,7 +44,7 @@ use sas_state::{
|
||||||
/// Short authentication string object.
|
/// Short authentication string object.
|
||||||
pub struct Sas {
|
pub struct Sas {
|
||||||
inner: Arc<Mutex<InnerSas>>,
|
inner: Arc<Mutex<InnerSas>>,
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
account: Account,
|
account: Account,
|
||||||
other_device: Device,
|
other_device: Device,
|
||||||
flow_id: Arc<String>,
|
flow_id: Arc<String>,
|
||||||
|
@ -100,7 +99,7 @@ impl Sas {
|
||||||
pub(crate) fn start(
|
pub(crate) fn start(
|
||||||
account: Account,
|
account: Account,
|
||||||
other_device: Device,
|
other_device: Device,
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
) -> (Sas, StartEventContent) {
|
) -> (Sas, StartEventContent) {
|
||||||
let (inner, content) = InnerSas::start(account.clone(), other_device.clone());
|
let (inner, content) = InnerSas::start(account.clone(), other_device.clone());
|
||||||
let flow_id = inner.verification_flow_id();
|
let flow_id = inner.verification_flow_id();
|
||||||
|
@ -129,7 +128,7 @@ impl Sas {
|
||||||
pub(crate) fn from_start_event(
|
pub(crate) fn from_start_event(
|
||||||
account: Account,
|
account: Account,
|
||||||
other_device: Device,
|
other_device: Device,
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
event: &ToDeviceEvent<StartEventContent>,
|
event: &ToDeviceEvent<StartEventContent>,
|
||||||
) -> Result<Sas, AnyToDeviceEventContent> {
|
) -> Result<Sas, AnyToDeviceEventContent> {
|
||||||
let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event)?;
|
let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event)?;
|
||||||
|
@ -184,8 +183,6 @@ impl Sas {
|
||||||
pub(crate) async fn mark_device_as_verified(&self) -> Result<bool, CryptoStoreError> {
|
pub(crate) async fn mark_device_as_verified(&self) -> Result<bool, CryptoStoreError> {
|
||||||
let device = self
|
let device = self
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(self.other_user_id(), self.other_device_id())
|
.get_device(self.other_user_id(), self.other_device_id())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -202,7 +199,7 @@ impl Sas {
|
||||||
);
|
);
|
||||||
|
|
||||||
device.set_trust_state(TrustState::Verified);
|
device.set_trust_state(TrustState::Verified);
|
||||||
self.store.read().await.save_devices(&[device]).await?;
|
self.store.save_devices(&[device]).await?;
|
||||||
|
|
||||||
Ok(true)
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
|
@ -560,7 +557,6 @@ mod test {
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
events::{EventContent, ToDeviceEvent},
|
events::{EventContent, ToDeviceEvent},
|
||||||
identifiers::{DeviceId, UserId},
|
identifiers::{DeviceId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -685,14 +681,10 @@ mod test {
|
||||||
let bob = Account::new(&bob_id(), &bob_device_id());
|
let bob = Account::new(&bob_id(), &bob_device_id());
|
||||||
let bob_device = Device::from_account(&bob).await;
|
let bob_device = Device::from_account(&bob).await;
|
||||||
|
|
||||||
let alice_store: Arc<RwLock<Box<dyn CryptoStore>>> =
|
let alice_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||||
Arc::new(RwLock::new(Box::new(MemoryStore::new())));
|
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||||
let bob_store: Arc<RwLock<Box<dyn CryptoStore>>> =
|
|
||||||
Arc::new(RwLock::new(Box::new(MemoryStore::new())));
|
|
||||||
|
|
||||||
bob_store
|
bob_store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.save_devices(&[alice_device.clone()])
|
.save_devices(&[alice_device.clone()])
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
@ -98,6 +98,7 @@ impl TryFrom<AcceptV1Content> for AcceptedProtocols {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl Default for AcceptedProtocols {
|
impl Default for AcceptedProtocols {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
AcceptedProtocols {
|
AcceptedProtocols {
|
||||||
|
@ -146,6 +147,7 @@ pub struct SasState<S: Clone> {
|
||||||
state: Arc<S>,
|
state: Arc<S>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(not(tarpaulin_include))]
|
||||||
impl<S: Clone + std::fmt::Debug> std::fmt::Debug for SasState<S> {
|
impl<S: Clone + std::fmt::Debug> std::fmt::Debug for SasState<S> {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
f.debug_struct("SasState")
|
f.debug_struct("SasState")
|
||||||
|
|
|
@ -11,9 +11,9 @@ repository = "https://github.com/matrix-org/matrix-rust-sdk"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
serde_json = "1.0.56"
|
serde_json = "1.0.57"
|
||||||
http = "0.2.1"
|
http = "0.2.1"
|
||||||
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
|
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
|
||||||
matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" }
|
matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" }
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
serde = "1.0.114"
|
serde = "1.0.115"
|
||||||
|
|
Loading…
Reference in New Issue