crypto: Prepare the scaffolding for key queries and user tracking.

master
Damir Jelić 2020-04-01 15:37:00 +02:00
parent fdb2028dfc
commit 2020700673
7 changed files with 210 additions and 5 deletions

View File

@ -13,12 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::future::{BoxFuture, Future, FutureExt};
use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock as SyncLock};
use std::time::{Duration, Instant};
use futures::future::{BoxFuture, Future, FutureExt};
use tokio::sync::RwLock;
use tokio::time::delay_for as sleep;
use tracing::{debug, info, instrument, trace};
@ -36,6 +38,9 @@ use ruma_events::EventResult;
pub use ruma_events::EventType;
use ruma_identifiers::RoomId;
#[cfg(feature = "encryption")]
use ruma_identifiers::{DeviceId, UserId};
use crate::api;
use crate::base_client::Client as BaseClient;
use crate::models::Room;
@ -185,6 +190,8 @@ impl SyncSettings {
}
}
#[cfg(feature = "encryption")]
use api::r0::keys::get_keys;
#[cfg(feature = "encryption")]
use api::r0::keys::upload_keys;
use api::r0::message::create_message_event;
@ -658,6 +665,11 @@ impl AsyncClient {
if self.base_client.read().await.should_upload_keys().await {
let _ = self.keys_upload().await;
}
if self.base_client.read().await.should_query_keys().await {
// TODO enable this
// let _ = self.keys_query().await;
}
}
let now = Instant::now();
@ -829,4 +841,48 @@ impl AsyncClient {
pub async fn sync_token(&self) -> Option<String> {
self.base_client.read().await.sync_token.clone()
}
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[instrument]
/// Query the server for users device keys.
///
/// # Panics
///
/// Panics if no key query needs to be done.
async fn keys_query(&self) -> Result<get_keys::Response> {
let mut users_for_query = self
.base_client
.read()
.await
.users_for_key_query()
.await
.expect("Keys don't need to be uploaded");
debug!(
"Querying device keys device for users: {:?}",
users_for_query
);
let mut device_keys: HashMap<UserId, Vec<DeviceId>> = HashMap::new();
for user in users_for_query.drain() {
device_keys.insert(UserId::try_from(user.as_ref()).unwrap(), Vec::new());
}
let request = get_keys::Request {
timeout: None,
device_keys,
token: None,
};
let response = self.send(request).await?;
self.base_client
.write()
.await
.receive_keys_query_response(&response)
.await?;
Ok(response)
}
}

View File

@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
#[cfg(feature = "encryption")]
@ -47,7 +47,10 @@ use tokio::sync::Mutex;
#[cfg(feature = "encryption")]
use crate::crypto::{OlmMachine, OneTimeKeys};
#[cfg(feature = "encryption")]
use ruma_client_api::r0::keys::{upload_keys::Response as KeysUploadResponse, DeviceKeys};
use ruma_client_api::r0::keys::{
get_keys::Response as KeysQueryResponse, upload_keys::Response as KeysUploadResponse,
DeviceKeys,
};
use ruma_identifiers::RoomId;
pub type Token = String;
@ -378,6 +381,18 @@ impl Client {
if let Some(o) = &mut *olm {
o.receive_sync_response(response).await;
// TODO once the base client deals with callbacks move this into the
// part where we already iterate through the rooms to avoid yet
// another room loop.
for room in self.joined_rooms.values() {
let room = room.read().unwrap();
if !room.is_encrypted() {
continue;
}
o.update_tracked_users(room.members.keys()).await;
}
}
}
}
@ -394,6 +409,18 @@ impl Client {
}
}
/// Should users be queried for their device keys.
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn should_query_keys(&self) -> bool {
let olm = self.olm.lock().await;
match &*olm {
Some(o) => o.should_query_keys(),
None => false,
}
}
/// Get a tuple of device and one-time keys that need to be uploaded.
///
/// Returns an empty error if no keys need to be uploaded.
@ -410,6 +437,20 @@ impl Client {
}
}
/// Get the users that we need to query keys for.
///
/// Returns an empty error if no keys need to be queried.
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn users_for_key_query(&self) -> StdResult<HashSet<String>, ()> {
let olm = self.olm.lock().await;
match &*olm {
Some(o) => Ok(o.users_for_key_query()),
None => Err(()),
}
}
/// Receive a successful keys upload response.
///
/// # Arguments
@ -428,6 +469,26 @@ impl Client {
o.receive_keys_upload_response(response).await?;
Ok(())
}
/// Receive a successful keys query response.
///
/// # Arguments
///
/// * `response` - The keys query response of the request that the client
/// performed.
///
/// # Panics
/// Panics if the client hasn't been logged in.
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn receive_keys_query_response(&self, response: &KeysQueryResponse) -> Result<()> {
let mut olm = self.olm.lock().await;
let o = olm.as_mut().expect("Client isn't logged in.");
o.receive_keys_query_response(response).await?;
// TODO notify our callers of new devices via some callback.
Ok(())
}
}
#[cfg(test)]

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::convert::TryInto;
#[cfg(feature = "sqlite-cryptostore")]
use std::path::Path;
@ -70,6 +70,9 @@ pub struct OlmMachine {
/// Persists all the encrytpion keys so a client can resume the session
/// without the need to create new keys.
store: Box<dyn CryptoStore>,
/// Set of users that we need to query keys for. This is a subset of
/// the tracked users in the CryptoStore.
users_for_key_query: HashSet<String>,
}
impl OlmMachine {
@ -86,6 +89,7 @@ impl OlmMachine {
account: Arc::new(Mutex::new(Account::new())),
uploaded_signed_key_count: None,
store: Box::new(MemoryStore::new()),
users_for_key_query: HashSet::new(),
})
}
@ -112,12 +116,14 @@ impl OlmMachine {
}
};
// TODO load the tracked users here.
Ok(OlmMachine {
user_id: user_id.clone(),
device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(account)),
uploaded_signed_key_count: None,
store: Box::new(store),
users_for_key_query: HashSet::new(),
})
}
@ -176,6 +182,21 @@ impl OlmMachine {
Ok(())
}
/// Receive a successful keys query response.
///
/// # Arguments
///
/// * `response` - The keys query response of the request that the client
/// performed.
// TODO this should return a
#[instrument]
pub async fn receive_keys_query_response(
&mut self,
response: &keys::get_keys::Response,
) -> Result<()> {
todo!()
}
/// Generate new one-time keys.
///
/// Returns the number of newly generated one-time keys. If no keys can be
@ -680,6 +701,44 @@ impl OlmMachine {
Ok(decrypted_event)
}
/// Update the tracked users.
///
/// This will only not already seen users for a key query and user tracking.
/// If the user is already known to the Olm machine it will not be
/// considered for a key query.
///
/// Use the `mark_user_as_changed()` if the user really needs a key query.
pub async fn update_tracked_users<'a, I>(&mut self, users: I)
where
I: IntoIterator<Item = &'a String>,
{
for user in users {
let ret = self.store.add_user_for_tracking(user).await;
match ret {
Ok(newly_added) => {
if newly_added {
self.users_for_key_query.insert(user.to_string());
}
}
Err(e) => {
warn!("Error storing users for tracking {}", e);
self.users_for_key_query.insert(user.to_string());
}
}
}
}
/// Should a key query be done.
pub fn should_query_keys(&self) -> bool {
!self.users_for_key_query.is_empty()
}
/// Get the set of users that we need to query keys for.
pub fn users_for_key_query(&self) -> HashSet<String> {
self.users_for_key_query.clone()
}
}
#[cfg(test)]

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
@ -24,6 +25,7 @@ use crate::crypto::memory_stores::{GroupSessionStore, SessionStore};
pub struct MemoryStore {
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
tracked_users: HashSet<String>,
}
impl MemoryStore {
@ -31,6 +33,7 @@ impl MemoryStore {
MemoryStore {
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
tracked_users: HashSet::new(),
}
}
}
@ -75,4 +78,12 @@ impl CryptoStore for MemoryStore {
.inbound_group_sessions
.get(room_id, sender_key, session_id))
}
fn tracked_users(&self) -> &HashSet<String> {
&self.tracked_users
}
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> {
Ok(self.tracked_users.insert(user.to_string()))
}
}

View File

@ -13,7 +13,7 @@
// limitations under the License.
use core::fmt::Debug;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::io::Error as IoError;
use std::result::Result as StdResult;
use std::sync::Arc;
@ -79,4 +79,6 @@ pub trait CryptoStore: Debug + Send + Sync {
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
fn tracked_users(&self) -> &HashSet<String>;
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool>;
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::result::Result as StdResult;
use std::sync::Arc;
@ -37,6 +38,7 @@ pub struct SqliteStore {
inbound_group_sessions: GroupSessionStore,
connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>,
tracked_users: HashSet<String>,
}
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
@ -83,6 +85,7 @@ impl SqliteStore {
path: path.as_ref().to_owned(),
connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: passphrase,
tracked_users: HashSet::new(),
};
store.create_tables().await?;
Ok(store)
@ -397,6 +400,14 @@ impl CryptoStore for SqliteStore {
.inbound_group_sessions
.get(room_id, sender_key, session_id))
}
fn tracked_users(&self) -> &HashSet<String> {
&self.tracked_users
}
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> {
todo!()
}
}
impl std::fmt::Debug for SqliteStore {

View File

@ -149,6 +149,11 @@ impl Room {
self.room_name.calculate_name(&self.room_id, &self.members)
}
/// Is the room a encrypted room.
pub fn is_encrypted(&self) -> bool {
self.encrypted
}
fn add_member(&mut self, event: &MemberEvent) -> bool {
if self.members.contains_key(&event.state_key) {
return false;