Merge branch 'master' into state-store

This commit is contained in:
Devin R 2020-04-27 16:45:57 -04:00
commit c91263eb13
4 changed files with 141 additions and 18 deletions

View file

@ -1008,7 +1008,7 @@ impl AsyncClient {
.read()
.await
.get_missing_sessions(users)
.await
.await?
};
if !missing_sessions.is_empty() {

View file

@ -501,12 +501,12 @@ impl Client {
pub async fn get_missing_sessions(
&self,
users: impl Iterator<Item = &UserId>,
) -> BTreeMap<UserId, BTreeMap<DeviceId, KeyAlgorithm>> {
) -> Result<BTreeMap<UserId, BTreeMap<DeviceId, KeyAlgorithm>>> {
let mut olm = self.olm.lock().await;
match &mut *olm {
Some(o) => o.get_missing_sessions(users).await,
None => BTreeMap::new(),
Some(o) => Ok(o.get_missing_sessions(users).await?),
None => Ok(BTreeMap::new()),
}
}

View file

@ -13,12 +13,16 @@
// limitations under the License.
use std::collections::BTreeMap;
#[cfg(test)]
use std::convert::TryFrom;
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use atomic::Atomic;
#[cfg(test)]
use super::OlmMachine;
use crate::api::r0::keys::{DeviceKeys, KeyAlgorithm};
use crate::events::Algorithm;
use crate::identifiers::{DeviceId, UserId};
@ -151,6 +155,36 @@ impl Device {
}
}
#[cfg(test)]
impl From<&OlmMachine> for Device {
fn from(machine: &OlmMachine) -> Self {
Device {
user_id: Arc::new(machine.user_id.clone()),
device_id: Arc::new(machine.device_id.clone()),
algorithms: Arc::new(vec![
Algorithm::MegolmV1AesSha2,
Algorithm::OlmV1Curve25519AesSha2,
]),
keys: Arc::new(
machine
.account
.identity_keys()
.iter()
.map(|(key, value)| {
(
KeyAlgorithm::try_from(key.as_ref()).unwrap(),
value.to_owned(),
)
})
.collect(),
),
display_name: Arc::new(None),
deleted: Arc::new(AtomicBool::new(false)),
trust_state: Arc::new(Atomic::new(TrustState::Unset)),
}
}
}
impl From<&DeviceKeys> for Device {
fn from(device_keys: &DeviceKeys) -> Self {
let mut keys = BTreeMap::new();

View file

@ -61,11 +61,11 @@ pub type OneTimeKeys = BTreeMap<AlgorithmAndDeviceId, OneTimeKey>;
pub struct OlmMachine {
/// The unique user id that owns this account.
user_id: UserId,
pub(crate) user_id: UserId,
/// The unique device id of the device that holds this account.
device_id: DeviceId,
pub(crate) device_id: DeviceId,
/// Our underlying Olm Account holding our identity keys.
account: Account,
pub(crate) account: Account,
/// The number of signed one-time keys we have uploaded to the server. If
/// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the
@ -202,11 +202,11 @@ impl OlmMachine {
pub async fn get_missing_sessions(
&mut self,
users: impl Iterator<Item = &UserId>,
) -> BTreeMap<UserId, BTreeMap<DeviceId, KeyAlgorithm>> {
) -> Result<BTreeMap<UserId, BTreeMap<DeviceId, KeyAlgorithm>>> {
let mut missing = BTreeMap::new();
for user_id in users {
let user_devices = self.store.get_user_devices(user_id).await.unwrap();
let user_devices = self.store.get_user_devices(user_id).await?;
for device in user_devices.devices() {
let sender_key = if let Some(k) = device.get_key(&KeyAlgorithm::Curve25519) {
@ -215,7 +215,7 @@ impl OlmMachine {
continue;
};
let sessions = self.store.get_sessions(sender_key).await.unwrap();
let sessions = self.store.get_sessions(sender_key).await?;
let is_missing = if let Some(sessions) = sessions {
sessions.lock().await.is_empty()
@ -237,7 +237,7 @@ impl OlmMachine {
}
}
missing
Ok(missing)
}
pub async fn receive_keys_claim_response(
@ -1366,21 +1366,32 @@ impl OlmMachine {
#[cfg(test)]
mod test {
static USER_ID: &str = "@test:example.org";
const DEVICE_ID: &str = "DEVICEID";
static USER_ID: &str = "@bob:example.org";
static DEVICE_ID: &str = "DEVICEID";
use js_int::UInt;
use std::collections::BTreeMap;
use std::convert::TryFrom;
use std::fs::File;
use std::io::prelude::*;
use ruma_identifiers::UserId;
use serde_json::json;
use crate::api::r0::keys;
use crate::crypto::machine::OlmMachine;
use crate::crypto::machine::{OlmMachine, OneTimeKeys};
use crate::crypto::Device;
use crate::identifiers::{DeviceId, UserId};
use http::Response;
fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap()
}
fn alice_device_id() -> DeviceId {
"JLAFKJWSCS".to_string()
}
fn user_id() -> UserId {
UserId::try_from(USER_ID).unwrap()
}
@ -1405,10 +1416,10 @@ mod test {
keys::get_keys::Response::try_from(data).expect("Can't parse the keys upload response")
}
async fn get_prepared_machine() -> OlmMachine {
async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(0);
let (_, _) = machine
let (_, otk) = machine
.keys_for_upload()
.await
.expect("Can't prepare initial key upload");
@ -1418,7 +1429,34 @@ mod test {
.await
.unwrap();
(machine, otk.unwrap())
}
async fn get_machine_after_query() -> (OlmMachine, OneTimeKeys) {
let (mut machine, otk) = get_prepared_machine().await;
let response = keys_query_response();
machine
.receive_keys_query_response(&response)
.await
.unwrap();
(machine, otk)
}
async fn get_machine_pair() -> (OlmMachine, OlmMachine, OneTimeKeys) {
let (bob, otk) = get_prepared_machine().await;
let alice_id = alice_id();
let alice_device = alice_device_id();
let alice = OlmMachine::new(&alice_id, &alice_device).unwrap();
let alice_deivce = Device::from(&alice);
let bob_device = Device::from(&bob);
alice.store.save_device(bob_device).await.unwrap();
bob.store.save_device(alice_deivce).await.unwrap();
(alice, bob, otk)
}
#[tokio::test]
@ -1590,7 +1628,7 @@ mod test {
#[tokio::test]
async fn test_keys_query() {
let mut machine = get_prepared_machine().await;
let (mut machine, _) = get_prepared_machine().await;
let response = keys_query_response();
let alice_id = UserId::try_from("@alice:example.org").unwrap();
let alice_device_id = "JLAFKJWSCS".to_owned();
@ -1612,4 +1650,55 @@ mod test {
assert_eq!(device.user_id(), &alice_id);
assert_eq!(device.device_id(), &alice_device_id);
}
#[tokio::test]
async fn test_missing_sessions_calculation() {
let (mut machine, _) = get_machine_after_query().await;
let alice = alice_id();
let alice_device = alice_device_id();
let missing_sessions = machine
.get_missing_sessions([alice.clone()].iter())
.await
.unwrap();
assert!(missing_sessions.contains_key(&alice));
let user_sessions = missing_sessions.get(&alice).unwrap();
assert!(user_sessions.contains_key(&alice_device));
}
#[tokio::test]
async fn test_key_claiming() {
let (mut alice_machine, bob_machine, one_time_keys) = get_machine_pair().await;
let mut bob_keys = BTreeMap::new();
let one_time_key = one_time_keys.iter().nth(0).unwrap();
let mut keys = BTreeMap::new();
keys.insert(one_time_key.0.clone(), one_time_key.1.clone());
bob_keys.insert(bob_machine.device_id.clone(), keys);
let mut one_time_keys = BTreeMap::new();
one_time_keys.insert(bob_machine.user_id.clone(), bob_keys);
let response = keys::claim_keys::Response {
failures: BTreeMap::new(),
one_time_keys,
};
alice_machine
.receive_keys_claim_response(&response)
.await
.unwrap();
let session = alice_machine
.store
.get_sessions(bob_machine.account.identity_keys().curve25519())
.await
.unwrap()
.unwrap();
assert!(!session.lock().await.is_empty())
}
}