diff --git a/src/async_client.rs b/src/async_client.rs index 43e2c4f3..48aac0b0 100644 --- a/src/async_client.rs +++ b/src/async_client.rs @@ -13,12 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use futures::future::{BoxFuture, Future, FutureExt}; +use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::result::Result as StdResult; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex, RwLock as SyncLock}; use std::time::{Duration, Instant}; + +use futures::future::{BoxFuture, Future, FutureExt}; use tokio::sync::RwLock; use tokio::time::delay_for as sleep; use tracing::{debug, info, instrument, trace}; @@ -36,6 +38,9 @@ use ruma_events::EventResult; pub use ruma_events::EventType; use ruma_identifiers::RoomId; +#[cfg(feature = "encryption")] +use ruma_identifiers::{DeviceId, UserId}; + use crate::api; use crate::base_client::Client as BaseClient; use crate::models::Room; @@ -185,6 +190,8 @@ impl SyncSettings { } } +#[cfg(feature = "encryption")] +use api::r0::keys::get_keys; #[cfg(feature = "encryption")] use api::r0::keys::upload_keys; use api::r0::message::create_message_event; @@ -658,6 +665,11 @@ impl AsyncClient { if self.base_client.read().await.should_upload_keys().await { let _ = self.keys_upload().await; } + + if self.base_client.read().await.should_query_keys().await { + // TODO enable this + // let _ = self.keys_query().await; + } } let now = Instant::now(); @@ -829,4 +841,48 @@ impl AsyncClient { pub async fn sync_token(&self) -> Option { self.base_client.read().await.sync_token.clone() } + + #[cfg(feature = "encryption")] + #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] + #[instrument] + /// Query the server for users device keys. + /// + /// # Panics + /// + /// Panics if no key query needs to be done. + async fn keys_query(&self) -> Result { + let mut users_for_query = self + .base_client + .read() + .await + .users_for_key_query() + .await + .expect("Keys don't need to be uploaded"); + + debug!( + "Querying device keys device for users: {:?}", + users_for_query + ); + + let mut device_keys: HashMap> = HashMap::new(); + + for user in users_for_query.drain() { + device_keys.insert(UserId::try_from(user.as_ref()).unwrap(), Vec::new()); + } + + let request = get_keys::Request { + timeout: None, + device_keys, + token: None, + }; + + let response = self.send(request).await?; + self.base_client + .write() + .await + .receive_keys_query_response(&response) + .await?; + + Ok(response) + } } diff --git a/src/base_client.rs b/src/base_client.rs index b910d53d..1572356e 100644 --- a/src/base_client.rs +++ b/src/base_client.rs @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; #[cfg(feature = "encryption")] @@ -47,7 +47,10 @@ use tokio::sync::Mutex; #[cfg(feature = "encryption")] use crate::crypto::{OlmMachine, OneTimeKeys}; #[cfg(feature = "encryption")] -use ruma_client_api::r0::keys::{upload_keys::Response as KeysUploadResponse, DeviceKeys}; +use ruma_client_api::r0::keys::{ + get_keys::Response as KeysQueryResponse, upload_keys::Response as KeysUploadResponse, + DeviceKeys, +}; use ruma_identifiers::RoomId; pub type Token = String; @@ -378,6 +381,18 @@ impl Client { if let Some(o) = &mut *olm { o.receive_sync_response(response).await; + + // TODO once the base client deals with callbacks move this into the + // part where we already iterate through the rooms to avoid yet + // another room loop. + for room in self.joined_rooms.values() { + let room = room.read().unwrap(); + if !room.is_encrypted() { + continue; + } + + o.update_tracked_users(room.members.keys()).await; + } } } } @@ -394,6 +409,18 @@ impl Client { } } + /// Should users be queried for their device keys. + #[cfg(feature = "encryption")] + #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] + pub async fn should_query_keys(&self) -> bool { + let olm = self.olm.lock().await; + + match &*olm { + Some(o) => o.should_query_keys(), + None => false, + } + } + /// Get a tuple of device and one-time keys that need to be uploaded. /// /// Returns an empty error if no keys need to be uploaded. @@ -410,6 +437,20 @@ impl Client { } } + /// Get the users that we need to query keys for. + /// + /// Returns an empty error if no keys need to be queried. + #[cfg(feature = "encryption")] + #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] + pub async fn users_for_key_query(&self) -> StdResult, ()> { + let olm = self.olm.lock().await; + + match &*olm { + Some(o) => Ok(o.users_for_key_query()), + None => Err(()), + } + } + /// Receive a successful keys upload response. /// /// # Arguments @@ -428,6 +469,26 @@ impl Client { o.receive_keys_upload_response(response).await?; Ok(()) } + + /// Receive a successful keys query response. + /// + /// # Arguments + /// + /// * `response` - The keys query response of the request that the client + /// performed. + /// + /// # Panics + /// Panics if the client hasn't been logged in. + #[cfg(feature = "encryption")] + #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] + pub async fn receive_keys_query_response(&self, response: &KeysQueryResponse) -> Result<()> { + let mut olm = self.olm.lock().await; + + let o = olm.as_mut().expect("Client isn't logged in."); + o.receive_keys_query_response(response).await?; + // TODO notify our callers of new devices via some callback. + Ok(()) + } } #[cfg(test)] diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 73d8735d..ba975101 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::convert::TryInto; #[cfg(feature = "sqlite-cryptostore")] use std::path::Path; @@ -70,6 +70,9 @@ pub struct OlmMachine { /// Persists all the encrytpion keys so a client can resume the session /// without the need to create new keys. store: Box, + /// Set of users that we need to query keys for. This is a subset of + /// the tracked users in the CryptoStore. + users_for_key_query: HashSet, } impl OlmMachine { @@ -86,6 +89,7 @@ impl OlmMachine { account: Arc::new(Mutex::new(Account::new())), uploaded_signed_key_count: None, store: Box::new(MemoryStore::new()), + users_for_key_query: HashSet::new(), }) } @@ -112,12 +116,14 @@ impl OlmMachine { } }; + // TODO load the tracked users here. Ok(OlmMachine { user_id: user_id.clone(), device_id: device_id.to_owned(), account: Arc::new(Mutex::new(account)), uploaded_signed_key_count: None, store: Box::new(store), + users_for_key_query: HashSet::new(), }) } @@ -176,6 +182,21 @@ impl OlmMachine { Ok(()) } + /// Receive a successful keys query response. + /// + /// # Arguments + /// + /// * `response` - The keys query response of the request that the client + /// performed. + // TODO this should return a + #[instrument] + pub async fn receive_keys_query_response( + &mut self, + response: &keys::get_keys::Response, + ) -> Result<()> { + todo!() + } + /// Generate new one-time keys. /// /// Returns the number of newly generated one-time keys. If no keys can be @@ -680,6 +701,44 @@ impl OlmMachine { Ok(decrypted_event) } + + /// Update the tracked users. + /// + /// This will only not already seen users for a key query and user tracking. + /// If the user is already known to the Olm machine it will not be + /// considered for a key query. + /// + /// Use the `mark_user_as_changed()` if the user really needs a key query. + pub async fn update_tracked_users<'a, I>(&mut self, users: I) + where + I: IntoIterator, + { + for user in users { + let ret = self.store.add_user_for_tracking(user).await; + + match ret { + Ok(newly_added) => { + if newly_added { + self.users_for_key_query.insert(user.to_string()); + } + } + Err(e) => { + warn!("Error storing users for tracking {}", e); + self.users_for_key_query.insert(user.to_string()); + } + } + } + } + + /// Should a key query be done. + pub fn should_query_keys(&self) -> bool { + !self.users_for_key_query.is_empty() + } + + /// Get the set of users that we need to query keys for. + pub fn users_for_key_query(&self) -> HashSet { + self.users_for_key_query.clone() + } } #[cfg(test)] diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index 9c226096..34540748 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashSet; use std::sync::Arc; use async_trait::async_trait; @@ -24,6 +25,7 @@ use crate::crypto::memory_stores::{GroupSessionStore, SessionStore}; pub struct MemoryStore { sessions: SessionStore, inbound_group_sessions: GroupSessionStore, + tracked_users: HashSet, } impl MemoryStore { @@ -31,6 +33,7 @@ impl MemoryStore { MemoryStore { sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), + tracked_users: HashSet::new(), } } } @@ -75,4 +78,12 @@ impl CryptoStore for MemoryStore { .inbound_group_sessions .get(room_id, sender_key, session_id)) } + + fn tracked_users(&self) -> &HashSet { + &self.tracked_users + } + + async fn add_user_for_tracking(&mut self, user: &str) -> Result { + Ok(self.tracked_users.insert(user.to_string())) + } } diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index f2ee300a..3279aa74 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -13,7 +13,7 @@ // limitations under the License. use core::fmt::Debug; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::io::Error as IoError; use std::result::Result as StdResult; use std::sync::Arc; @@ -79,4 +79,6 @@ pub trait CryptoStore: Debug + Send + Sync { sender_key: &str, session_id: &str, ) -> Result>>>; + fn tracked_users(&self) -> &HashSet; + async fn add_user_for_tracking(&mut self, user: &str) -> Result; } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 01bfdda0..ca39c4de 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashSet; use std::path::{Path, PathBuf}; use std::result::Result as StdResult; use std::sync::Arc; @@ -37,6 +38,7 @@ pub struct SqliteStore { inbound_group_sessions: GroupSessionStore, connection: Arc>, pickle_passphrase: Option>, + tracked_users: HashSet, } static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; @@ -83,6 +85,7 @@ impl SqliteStore { path: path.as_ref().to_owned(), connection: Arc::new(Mutex::new(connection)), pickle_passphrase: passphrase, + tracked_users: HashSet::new(), }; store.create_tables().await?; Ok(store) @@ -397,6 +400,14 @@ impl CryptoStore for SqliteStore { .inbound_group_sessions .get(room_id, sender_key, session_id)) } + + fn tracked_users(&self) -> &HashSet { + &self.tracked_users + } + + async fn add_user_for_tracking(&mut self, user: &str) -> Result { + todo!() + } } impl std::fmt::Debug for SqliteStore { diff --git a/src/models/room.rs b/src/models/room.rs index 059ef383..4f609d5f 100644 --- a/src/models/room.rs +++ b/src/models/room.rs @@ -149,6 +149,11 @@ impl Room { self.room_name.calculate_name(&self.room_id, &self.members) } + /// Is the room a encrypted room. + pub fn is_encrypted(&self) -> bool { + self.encrypted + } + fn add_member(&mut self, event: &MemberEvent) -> bool { if self.members.contains_key(&event.state_key) { return false;