crypto: Prepare the scaffolding for key queries and user tracking.
This commit is contained in:
parent
fdb2028dfc
commit
2020700673
7 changed files with 210 additions and 5 deletions
|
@ -13,12 +13,14 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use futures::future::{BoxFuture, Future, FutureExt};
|
||||
use std::collections::HashMap;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, RwLock as SyncLock};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use futures::future::{BoxFuture, Future, FutureExt};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::delay_for as sleep;
|
||||
use tracing::{debug, info, instrument, trace};
|
||||
|
@ -36,6 +38,9 @@ use ruma_events::EventResult;
|
|||
pub use ruma_events::EventType;
|
||||
use ruma_identifiers::RoomId;
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use ruma_identifiers::{DeviceId, UserId};
|
||||
|
||||
use crate::api;
|
||||
use crate::base_client::Client as BaseClient;
|
||||
use crate::models::Room;
|
||||
|
@ -185,6 +190,8 @@ impl SyncSettings {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use api::r0::keys::get_keys;
|
||||
#[cfg(feature = "encryption")]
|
||||
use api::r0::keys::upload_keys;
|
||||
use api::r0::message::create_message_event;
|
||||
|
@ -658,6 +665,11 @@ impl AsyncClient {
|
|||
if self.base_client.read().await.should_upload_keys().await {
|
||||
let _ = self.keys_upload().await;
|
||||
}
|
||||
|
||||
if self.base_client.read().await.should_query_keys().await {
|
||||
// TODO enable this
|
||||
// let _ = self.keys_query().await;
|
||||
}
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
|
@ -829,4 +841,48 @@ impl AsyncClient {
|
|||
pub async fn sync_token(&self) -> Option<String> {
|
||||
self.base_client.read().await.sync_token.clone()
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
|
||||
#[instrument]
|
||||
/// Query the server for users device keys.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if no key query needs to be done.
|
||||
async fn keys_query(&self) -> Result<get_keys::Response> {
|
||||
let mut users_for_query = self
|
||||
.base_client
|
||||
.read()
|
||||
.await
|
||||
.users_for_key_query()
|
||||
.await
|
||||
.expect("Keys don't need to be uploaded");
|
||||
|
||||
debug!(
|
||||
"Querying device keys device for users: {:?}",
|
||||
users_for_query
|
||||
);
|
||||
|
||||
let mut device_keys: HashMap<UserId, Vec<DeviceId>> = HashMap::new();
|
||||
|
||||
for user in users_for_query.drain() {
|
||||
device_keys.insert(UserId::try_from(user.as_ref()).unwrap(), Vec::new());
|
||||
}
|
||||
|
||||
let request = get_keys::Request {
|
||||
timeout: None,
|
||||
device_keys,
|
||||
token: None,
|
||||
};
|
||||
|
||||
let response = self.send(request).await?;
|
||||
self.base_client
|
||||
.write()
|
||||
.await
|
||||
.receive_keys_query_response(&response)
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::convert::TryFrom;
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
|
@ -47,7 +47,10 @@ use tokio::sync::Mutex;
|
|||
#[cfg(feature = "encryption")]
|
||||
use crate::crypto::{OlmMachine, OneTimeKeys};
|
||||
#[cfg(feature = "encryption")]
|
||||
use ruma_client_api::r0::keys::{upload_keys::Response as KeysUploadResponse, DeviceKeys};
|
||||
use ruma_client_api::r0::keys::{
|
||||
get_keys::Response as KeysQueryResponse, upload_keys::Response as KeysUploadResponse,
|
||||
DeviceKeys,
|
||||
};
|
||||
use ruma_identifiers::RoomId;
|
||||
|
||||
pub type Token = String;
|
||||
|
@ -378,6 +381,18 @@ impl Client {
|
|||
|
||||
if let Some(o) = &mut *olm {
|
||||
o.receive_sync_response(response).await;
|
||||
|
||||
// TODO once the base client deals with callbacks move this into the
|
||||
// part where we already iterate through the rooms to avoid yet
|
||||
// another room loop.
|
||||
for room in self.joined_rooms.values() {
|
||||
let room = room.read().unwrap();
|
||||
if !room.is_encrypted() {
|
||||
continue;
|
||||
}
|
||||
|
||||
o.update_tracked_users(room.members.keys()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -394,6 +409,18 @@ impl Client {
|
|||
}
|
||||
}
|
||||
|
||||
/// Should users be queried for their device keys.
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
|
||||
pub async fn should_query_keys(&self) -> bool {
|
||||
let olm = self.olm.lock().await;
|
||||
|
||||
match &*olm {
|
||||
Some(o) => o.should_query_keys(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a tuple of device and one-time keys that need to be uploaded.
|
||||
///
|
||||
/// Returns an empty error if no keys need to be uploaded.
|
||||
|
@ -410,6 +437,20 @@ impl Client {
|
|||
}
|
||||
}
|
||||
|
||||
/// Get the users that we need to query keys for.
|
||||
///
|
||||
/// Returns an empty error if no keys need to be queried.
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
|
||||
pub async fn users_for_key_query(&self) -> StdResult<HashSet<String>, ()> {
|
||||
let olm = self.olm.lock().await;
|
||||
|
||||
match &*olm {
|
||||
Some(o) => Ok(o.users_for_key_query()),
|
||||
None => Err(()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive a successful keys upload response.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -428,6 +469,26 @@ impl Client {
|
|||
o.receive_keys_upload_response(response).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive a successful keys query response.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `response` - The keys query response of the request that the client
|
||||
/// performed.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if the client hasn't been logged in.
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
|
||||
pub async fn receive_keys_query_response(&self, response: &KeysQueryResponse) -> Result<()> {
|
||||
let mut olm = self.olm.lock().await;
|
||||
|
||||
let o = olm.as_mut().expect("Client isn't logged in.");
|
||||
o.receive_keys_query_response(response).await?;
|
||||
// TODO notify our callers of new devices via some callback.
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::convert::TryInto;
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
use std::path::Path;
|
||||
|
@ -70,6 +70,9 @@ pub struct OlmMachine {
|
|||
/// Persists all the encrytpion keys so a client can resume the session
|
||||
/// without the need to create new keys.
|
||||
store: Box<dyn CryptoStore>,
|
||||
/// Set of users that we need to query keys for. This is a subset of
|
||||
/// the tracked users in the CryptoStore.
|
||||
users_for_key_query: HashSet<String>,
|
||||
}
|
||||
|
||||
impl OlmMachine {
|
||||
|
@ -86,6 +89,7 @@ impl OlmMachine {
|
|||
account: Arc::new(Mutex::new(Account::new())),
|
||||
uploaded_signed_key_count: None,
|
||||
store: Box::new(MemoryStore::new()),
|
||||
users_for_key_query: HashSet::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -112,12 +116,14 @@ impl OlmMachine {
|
|||
}
|
||||
};
|
||||
|
||||
// TODO load the tracked users here.
|
||||
Ok(OlmMachine {
|
||||
user_id: user_id.clone(),
|
||||
device_id: device_id.to_owned(),
|
||||
account: Arc::new(Mutex::new(account)),
|
||||
uploaded_signed_key_count: None,
|
||||
store: Box::new(store),
|
||||
users_for_key_query: HashSet::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -176,6 +182,21 @@ impl OlmMachine {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive a successful keys query response.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `response` - The keys query response of the request that the client
|
||||
/// performed.
|
||||
// TODO this should return a
|
||||
#[instrument]
|
||||
pub async fn receive_keys_query_response(
|
||||
&mut self,
|
||||
response: &keys::get_keys::Response,
|
||||
) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
/// Generate new one-time keys.
|
||||
///
|
||||
/// Returns the number of newly generated one-time keys. If no keys can be
|
||||
|
@ -680,6 +701,44 @@ impl OlmMachine {
|
|||
|
||||
Ok(decrypted_event)
|
||||
}
|
||||
|
||||
/// Update the tracked users.
|
||||
///
|
||||
/// This will only not already seen users for a key query and user tracking.
|
||||
/// If the user is already known to the Olm machine it will not be
|
||||
/// considered for a key query.
|
||||
///
|
||||
/// Use the `mark_user_as_changed()` if the user really needs a key query.
|
||||
pub async fn update_tracked_users<'a, I>(&mut self, users: I)
|
||||
where
|
||||
I: IntoIterator<Item = &'a String>,
|
||||
{
|
||||
for user in users {
|
||||
let ret = self.store.add_user_for_tracking(user).await;
|
||||
|
||||
match ret {
|
||||
Ok(newly_added) => {
|
||||
if newly_added {
|
||||
self.users_for_key_query.insert(user.to_string());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error storing users for tracking {}", e);
|
||||
self.users_for_key_query.insert(user.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Should a key query be done.
|
||||
pub fn should_query_keys(&self) -> bool {
|
||||
!self.users_for_key_query.is_empty()
|
||||
}
|
||||
|
||||
/// Get the set of users that we need to query keys for.
|
||||
pub fn users_for_key_query(&self) -> HashSet<String> {
|
||||
self.users_for_key_query.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
@ -24,6 +25,7 @@ use crate::crypto::memory_stores::{GroupSessionStore, SessionStore};
|
|||
pub struct MemoryStore {
|
||||
sessions: SessionStore,
|
||||
inbound_group_sessions: GroupSessionStore,
|
||||
tracked_users: HashSet<String>,
|
||||
}
|
||||
|
||||
impl MemoryStore {
|
||||
|
@ -31,6 +33,7 @@ impl MemoryStore {
|
|||
MemoryStore {
|
||||
sessions: SessionStore::new(),
|
||||
inbound_group_sessions: GroupSessionStore::new(),
|
||||
tracked_users: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -75,4 +78,12 @@ impl CryptoStore for MemoryStore {
|
|||
.inbound_group_sessions
|
||||
.get(room_id, sender_key, session_id))
|
||||
}
|
||||
|
||||
fn tracked_users(&self) -> &HashSet<String> {
|
||||
&self.tracked_users
|
||||
}
|
||||
|
||||
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> {
|
||||
Ok(self.tracked_users.insert(user.to_string()))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
use core::fmt::Debug;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::io::Error as IoError;
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
|
@ -79,4 +79,6 @@ pub trait CryptoStore: Debug + Send + Sync {
|
|||
sender_key: &str,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
|
||||
fn tracked_users(&self) -> &HashSet<String>;
|
||||
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool>;
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
|
@ -37,6 +38,7 @@ pub struct SqliteStore {
|
|||
inbound_group_sessions: GroupSessionStore,
|
||||
connection: Arc<Mutex<SqliteConnection>>,
|
||||
pickle_passphrase: Option<Zeroizing<String>>,
|
||||
tracked_users: HashSet<String>,
|
||||
}
|
||||
|
||||
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
|
||||
|
@ -83,6 +85,7 @@ impl SqliteStore {
|
|||
path: path.as_ref().to_owned(),
|
||||
connection: Arc::new(Mutex::new(connection)),
|
||||
pickle_passphrase: passphrase,
|
||||
tracked_users: HashSet::new(),
|
||||
};
|
||||
store.create_tables().await?;
|
||||
Ok(store)
|
||||
|
@ -397,6 +400,14 @@ impl CryptoStore for SqliteStore {
|
|||
.inbound_group_sessions
|
||||
.get(room_id, sender_key, session_id))
|
||||
}
|
||||
|
||||
fn tracked_users(&self) -> &HashSet<String> {
|
||||
&self.tracked_users
|
||||
}
|
||||
|
||||
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SqliteStore {
|
||||
|
|
|
@ -149,6 +149,11 @@ impl Room {
|
|||
self.room_name.calculate_name(&self.room_id, &self.members)
|
||||
}
|
||||
|
||||
/// Is the room a encrypted room.
|
||||
pub fn is_encrypted(&self) -> bool {
|
||||
self.encrypted
|
||||
}
|
||||
|
||||
fn add_member(&mut self, event: &MemberEvent) -> bool {
|
||||
if self.members.contains_key(&event.state_key) {
|
||||
return false;
|
||||
|
|
Loading…
Reference in a new issue