crypto: Move the key count field into the account.
parent
b2ccb61864
commit
a7a9ac24ed
|
@ -18,7 +18,6 @@ use std::mem;
|
||||||
#[cfg(feature = "sqlite-cryptostore")]
|
#[cfg(feature = "sqlite-cryptostore")]
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::result::Result as StdResult;
|
use std::result::Result as StdResult;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
|
||||||
|
|
||||||
use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult, SignatureError};
|
use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult, SignatureError};
|
||||||
use super::olm::{
|
use super::olm::{
|
||||||
|
@ -67,11 +66,6 @@ pub struct OlmMachine {
|
||||||
device_id: DeviceId,
|
device_id: DeviceId,
|
||||||
/// Our underlying Olm Account holding our identity keys.
|
/// Our underlying Olm Account holding our identity keys.
|
||||||
account: Account,
|
account: Account,
|
||||||
/// The number of signed one-time keys we have uploaded to the server. If
|
|
||||||
/// this is None, no action will be taken. After a sync request the client
|
|
||||||
/// needs to set this for us, depending on the count we will suggest the
|
|
||||||
/// client to upload new keys.
|
|
||||||
uploaded_signed_key_count: Option<AtomicU64>,
|
|
||||||
/// Store for the encryption keys.
|
/// Store for the encryption keys.
|
||||||
/// Persists all the encryption keys so a client can resume the session
|
/// Persists all the encryption keys so a client can resume the session
|
||||||
/// without the need to create new keys.
|
/// without the need to create new keys.
|
||||||
|
@ -108,7 +102,6 @@ impl OlmMachine {
|
||||||
user_id: user_id.clone(),
|
user_id: user_id.clone(),
|
||||||
device_id: device_id.to_owned(),
|
device_id: device_id.to_owned(),
|
||||||
account: Account::new(user_id, &device_id),
|
account: Account::new(user_id, &device_id),
|
||||||
uploaded_signed_key_count: None,
|
|
||||||
store: Box::new(MemoryStore::new()),
|
store: Box::new(MemoryStore::new()),
|
||||||
outbound_group_sessions: HashMap::new(),
|
outbound_group_sessions: HashMap::new(),
|
||||||
}
|
}
|
||||||
|
@ -151,7 +144,6 @@ impl OlmMachine {
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
account,
|
account,
|
||||||
uploaded_signed_key_count: None,
|
|
||||||
store,
|
store,
|
||||||
outbound_group_sessions: HashMap::new(),
|
outbound_group_sessions: HashMap::new(),
|
||||||
})
|
})
|
||||||
|
@ -201,30 +193,24 @@ impl OlmMachine {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let count = self.account.uploaded_key_count() as u64;
|
||||||
|
|
||||||
// If we have a known key count, check that we have more than
|
// If we have a known key count, check that we have more than
|
||||||
// max_one_time_Keys() / 2, otherwise tell the client to upload more.
|
// max_one_time_Keys() / 2, otherwise tell the client to upload more.
|
||||||
match &self.uploaded_signed_key_count {
|
let max_keys = self.account.max_one_time_keys().await as u64;
|
||||||
Some(count) => {
|
// If there are more keys already uploaded than max_key / 2
|
||||||
let max_keys = self.account.max_one_time_keys().await as u64;
|
// bail out returning false, this also avoids overflow.
|
||||||
// If there are more keys already uploaded than max_key / 2
|
if count > (max_keys / 2) {
|
||||||
// bail out returning false, this also avoids overflow.
|
return false;
|
||||||
if count.load(Ordering::Relaxed) > (max_keys / 2) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let key_count = (max_keys / 2) - count.load(Ordering::Relaxed);
|
|
||||||
key_count > 0
|
|
||||||
}
|
|
||||||
None => false,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let key_count = (max_keys / 2) - count;
|
||||||
|
key_count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Update the count of one-time keys that are currently on the server.
|
/// Update the count of one-time keys that are currently on the server.
|
||||||
fn update_key_count(&mut self, count: u64) {
|
fn update_key_count(&mut self, count: u64) {
|
||||||
match &self.uploaded_signed_key_count {
|
self.account.update_uploaded_key_count(count);
|
||||||
Some(c) => c.store(count, Ordering::Relaxed),
|
|
||||||
None => self.uploaded_signed_key_count = Some(AtomicU64::new(count)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Receive a successful keys upload response.
|
/// Receive a successful keys upload response.
|
||||||
|
@ -250,9 +236,7 @@ impl OlmMachine {
|
||||||
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
|
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
|
||||||
debug!(
|
debug!(
|
||||||
"Updated uploaded one-time key count {} -> {}, marking keys as published",
|
"Updated uploaded one-time key count {} -> {}, marking keys as published",
|
||||||
self.uploaded_signed_key_count
|
self.account.uploaded_key_count(),
|
||||||
.as_ref()
|
|
||||||
.map_or(0, |c| c.load(Ordering::Relaxed)),
|
|
||||||
count
|
count
|
||||||
);
|
);
|
||||||
self.update_key_count(count);
|
self.update_key_count(count);
|
||||||
|
@ -543,26 +527,21 @@ impl OlmMachine {
|
||||||
/// Returns the number of newly generated one-time keys. If no keys can be
|
/// Returns the number of newly generated one-time keys. If no keys can be
|
||||||
/// generated returns an empty error.
|
/// generated returns an empty error.
|
||||||
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
|
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
|
||||||
match &self.uploaded_signed_key_count {
|
let count = self.account.uploaded_key_count() as u64;
|
||||||
Some(count) => {
|
// TODO if we store the uploaded key count with the Account all
|
||||||
// TODO if we store the uploaded key count with the Account all
|
// this logic could go into the account.
|
||||||
// this logic could go into the account.
|
let max_keys = self.account.max_one_time_keys().await;
|
||||||
let count = count.load(Ordering::Relaxed);
|
let max_on_server = (max_keys as u64) / 2;
|
||||||
let max_keys = self.account.max_one_time_keys().await;
|
|
||||||
let max_on_server = (max_keys as u64) / 2;
|
|
||||||
|
|
||||||
if count >= (max_on_server) {
|
if count >= (max_on_server) {
|
||||||
return Err(());
|
return Err(());
|
||||||
}
|
|
||||||
|
|
||||||
let key_count = (max_on_server) - count;
|
|
||||||
let key_count: usize = key_count.try_into().unwrap_or(max_keys);
|
|
||||||
|
|
||||||
self.account.generate_one_time_keys(key_count).await;
|
|
||||||
Ok(key_count as u64)
|
|
||||||
}
|
|
||||||
None => Err(()),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let key_count = (max_on_server) - count;
|
||||||
|
let key_count: usize = key_count.try_into().unwrap_or(max_keys);
|
||||||
|
|
||||||
|
self.account.generate_one_time_keys(key_count).await;
|
||||||
|
Ok(key_count as u64)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate, sign and prepare one-time keys to be uploaded.
|
/// Generate, sign and prepare one-time keys to be uploaded.
|
||||||
|
@ -1023,8 +1002,6 @@ impl OlmMachine {
|
||||||
// stores the curve key of the device, if we also store the ed25519 key
|
// stores the curve key of the device, if we also store the ed25519 key
|
||||||
// with the session we'll only need to pass in the account to the
|
// with the session we'll only need to pass in the account to the
|
||||||
// session and all of this can live in the session.
|
// session and all of this can live in the session.
|
||||||
//
|
|
||||||
// Storing a reference to the account is probably not worth the effort.
|
|
||||||
|
|
||||||
let recipient_signing_key = recipient_device
|
let recipient_signing_key = recipient_device
|
||||||
.get_key(KeyAlgorithm::Ed25519)
|
.get_key(KeyAlgorithm::Ed25519)
|
||||||
|
@ -1441,7 +1418,6 @@ mod test {
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
use std::sync::atomic::AtomicU64;
|
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
|
|
||||||
use http::Response;
|
use http::Response;
|
||||||
|
@ -1514,7 +1490,7 @@ mod test {
|
||||||
|
|
||||||
async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
|
async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
|
||||||
let mut machine = OlmMachine::new(&user_id(), &alice_device_id());
|
let mut machine = OlmMachine::new(&user_id(), &alice_device_id());
|
||||||
machine.uploaded_signed_key_count = Some(AtomicU64::new(0));
|
machine.account.update_uploaded_key_count(0);
|
||||||
let (_, otk) = machine
|
let (_, otk) = machine
|
||||||
.keys_for_upload()
|
.keys_for_upload()
|
||||||
.await
|
.await
|
||||||
|
@ -1660,7 +1636,6 @@ mod test {
|
||||||
let mut response = keys_upload_response();
|
let mut response = keys_upload_response();
|
||||||
|
|
||||||
assert!(machine.should_upload_keys().await);
|
assert!(machine.should_upload_keys().await);
|
||||||
assert!(machine.generate_one_time_keys().await.is_err());
|
|
||||||
|
|
||||||
machine
|
machine
|
||||||
.receive_keys_upload_response(&response)
|
.receive_keys_upload_response(&response)
|
||||||
|
@ -1729,8 +1704,8 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_one_time_key_signing() {
|
async fn test_one_time_key_signing() {
|
||||||
let mut machine = OlmMachine::new(&user_id(), &alice_device_id());
|
let machine = OlmMachine::new(&user_id(), &alice_device_id());
|
||||||
machine.uploaded_signed_key_count = Some(AtomicU64::new(49));
|
machine.account.update_uploaded_key_count(49);
|
||||||
|
|
||||||
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
|
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
|
||||||
let identity_keys = machine.account.identity_keys();
|
let identity_keys = machine.account.identity_keys();
|
||||||
|
@ -1750,7 +1725,7 @@ mod test {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_keys_for_upload() {
|
async fn test_keys_for_upload() {
|
||||||
let mut machine = OlmMachine::new(&user_id(), &alice_device_id());
|
let mut machine = OlmMachine::new(&user_id(), &alice_device_id());
|
||||||
machine.uploaded_signed_key_count = Some(AtomicU64::default());
|
machine.account.update_uploaded_key_count(0);
|
||||||
|
|
||||||
let identity_keys = machine.account.identity_keys();
|
let identity_keys = machine.account.identity_keys();
|
||||||
let ed25519_key = identity_keys.ed25519();
|
let ed25519_key = identity_keys.ed25519();
|
||||||
|
|
|
@ -13,9 +13,10 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use matrix_sdk_common::instant::Instant;
|
use matrix_sdk_common::instant::Instant;
|
||||||
|
use std::convert::TryFrom;
|
||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicUsize, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use matrix_sdk_common::locks::Mutex;
|
use matrix_sdk_common::locks::Mutex;
|
||||||
|
@ -61,6 +62,11 @@ pub struct Account {
|
||||||
inner: Arc<Mutex<OlmAccount>>,
|
inner: Arc<Mutex<OlmAccount>>,
|
||||||
identity_keys: Arc<IdentityKeys>,
|
identity_keys: Arc<IdentityKeys>,
|
||||||
shared: Arc<AtomicBool>,
|
shared: Arc<AtomicBool>,
|
||||||
|
/// The number of signed one-time keys we have uploaded to the server. If
|
||||||
|
/// this is None, no action will be taken. After a sync request the client
|
||||||
|
/// needs to set this for us, depending on the count we will suggest the
|
||||||
|
/// client to upload new keys.
|
||||||
|
uploaded_signed_key_count: Arc<AtomicI64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[cfg_attr(tarpaulin, skip)]
|
// #[cfg_attr(tarpaulin, skip)]
|
||||||
|
@ -90,6 +96,7 @@ impl Account {
|
||||||
inner: Arc::new(Mutex::new(account)),
|
inner: Arc::new(Mutex::new(account)),
|
||||||
identity_keys: Arc::new(identity_keys),
|
identity_keys: Arc::new(identity_keys),
|
||||||
shared: Arc::new(AtomicBool::new(false)),
|
shared: Arc::new(AtomicBool::new(false)),
|
||||||
|
uploaded_signed_key_count: Arc::new(AtomicI64::new(0)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,6 +105,22 @@ impl Account {
|
||||||
&self.identity_keys
|
&self.identity_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Update the uploaded key count.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `new_count` - The new count that was reported by the server.
|
||||||
|
pub fn update_uploaded_key_count(&self, new_count: u64) {
|
||||||
|
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX);
|
||||||
|
self.uploaded_signed_key_count
|
||||||
|
.store(key_count, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the currently known uploaded key count.
|
||||||
|
pub fn uploaded_key_count(&self) -> i64 {
|
||||||
|
self.uploaded_signed_key_count.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
/// Has the account been shared with the server.
|
/// Has the account been shared with the server.
|
||||||
pub fn shared(&self) -> bool {
|
pub fn shared(&self) -> bool {
|
||||||
self.shared.load(Ordering::Relaxed)
|
self.shared.load(Ordering::Relaxed)
|
||||||
|
@ -165,6 +188,7 @@ impl Account {
|
||||||
pickle: String,
|
pickle: String,
|
||||||
pickle_mode: PicklingMode,
|
pickle_mode: PicklingMode,
|
||||||
shared: bool,
|
shared: bool,
|
||||||
|
uploaded_signed_key_count: i64,
|
||||||
user_id: &UserId,
|
user_id: &UserId,
|
||||||
device_id: &DeviceId,
|
device_id: &DeviceId,
|
||||||
) -> Result<Self, OlmAccountError> {
|
) -> Result<Self, OlmAccountError> {
|
||||||
|
@ -177,6 +201,7 @@ impl Account {
|
||||||
inner: Arc::new(Mutex::new(account)),
|
inner: Arc::new(Mutex::new(account)),
|
||||||
identity_keys: Arc::new(identity_keys),
|
identity_keys: Arc::new(identity_keys),
|
||||||
shared: Arc::new(AtomicBool::from(shared)),
|
shared: Arc::new(AtomicBool::from(shared)),
|
||||||
|
uploaded_signed_key_count: Arc::new(AtomicI64::new(uploaded_signed_key_count)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,6 +144,7 @@ impl SqliteStore {
|
||||||
"device_id" TEXT NOT NULL,
|
"device_id" TEXT NOT NULL,
|
||||||
"pickle" BLOB NOT NULL,
|
"pickle" BLOB NOT NULL,
|
||||||
"shared" INTEGER NOT NULL,
|
"shared" INTEGER NOT NULL,
|
||||||
|
"uploaded_key_count" INTEGER NOT NULL,
|
||||||
UNIQUE(user_id,device_id)
|
UNIQUE(user_id,device_id)
|
||||||
);
|
);
|
||||||
"#,
|
"#,
|
||||||
|
@ -564,8 +565,8 @@ impl CryptoStore for SqliteStore {
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>> {
|
async fn load_account(&mut self) -> Result<Option<Account>> {
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
let row: Option<(i64, String, bool)> = query_as(
|
let row: Option<(i64, String, bool, i64)> = query_as(
|
||||||
"SELECT id, pickle, shared FROM accounts
|
"SELECT id, pickle, shared, uploaded_key_count FROM accounts
|
||||||
WHERE user_id = ? and device_id = ?",
|
WHERE user_id = ? and device_id = ?",
|
||||||
)
|
)
|
||||||
.bind(self.user_id.as_str())
|
.bind(self.user_id.as_str())
|
||||||
|
@ -573,12 +574,13 @@ impl CryptoStore for SqliteStore {
|
||||||
.fetch_optional(&mut *connection)
|
.fetch_optional(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let result = if let Some((id, pickle, shared)) = row {
|
let result = if let Some((id, pickle, shared, uploaded_key_count)) = row {
|
||||||
self.account_id = Some(id);
|
self.account_id = Some(id);
|
||||||
Some(Account::from_pickle(
|
Some(Account::from_pickle(
|
||||||
pickle,
|
pickle,
|
||||||
self.get_pickle_mode(),
|
self.get_pickle_mode(),
|
||||||
shared,
|
shared,
|
||||||
|
uploaded_key_count,
|
||||||
&self.user_id,
|
&self.user_id,
|
||||||
&self.device_id,
|
&self.device_id,
|
||||||
)?)
|
)?)
|
||||||
|
@ -613,8 +615,8 @@ impl CryptoStore for SqliteStore {
|
||||||
|
|
||||||
query(
|
query(
|
||||||
"INSERT INTO accounts (
|
"INSERT INTO accounts (
|
||||||
user_id, device_id, pickle, shared
|
user_id, device_id, pickle, shared, uploaded_key_count
|
||||||
) VALUES (?1, ?2, ?3, ?4)
|
) VALUES (?1, ?2, ?3, ?4, ?5)
|
||||||
ON CONFLICT(user_id, device_id) DO UPDATE SET
|
ON CONFLICT(user_id, device_id) DO UPDATE SET
|
||||||
pickle = excluded.pickle,
|
pickle = excluded.pickle,
|
||||||
shared = excluded.shared
|
shared = excluded.shared
|
||||||
|
@ -624,6 +626,7 @@ impl CryptoStore for SqliteStore {
|
||||||
.bind(&*self.device_id.to_string())
|
.bind(&*self.device_id.to_string())
|
||||||
.bind(&pickle)
|
.bind(&pickle)
|
||||||
.bind(account.shared())
|
.bind(account.shared())
|
||||||
|
.bind(account.uploaded_key_count())
|
||||||
.execute(&mut *connection)
|
.execute(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue