crypto compiles, tests pass

This commit is contained in:
Devin R 2020-04-03 11:00:37 -04:00
parent db62c1fa25
commit eff322c0c5
9 changed files with 83 additions and 79 deletions

View file

@ -623,7 +623,7 @@ impl AsyncClient {
/// * `data` - The content of the message.
pub async fn room_send(
&mut self,
room_id: &str,
room_id: &RoomId,
data: MessageEventContent,
) -> Result<create_message_event::Response> {
#[cfg(feature = "encryption")]
@ -658,7 +658,7 @@ impl AsyncClient {
}
let request = create_message_event::Request {
room_id: RoomId::try_from(room_id).unwrap(),
room_id: room_id.clone(),
event_type: EventType::RoomMessage,
txn_id: self.transaction_id().to_string(),
data,
@ -771,7 +771,7 @@ impl AsyncClient {
let mut device_keys: HashMap<UserId, Vec<DeviceId>> = HashMap::new();
for user in users_for_query.drain() {
device_keys.insert(UserId::try_from(user.as_ref()).unwrap(), Vec::new());
device_keys.insert(user, Vec::new());
}
let request = get_keys::Request {

View file

@ -400,8 +400,8 @@ impl Client {
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn get_missing_sessions(
&self,
users: impl Iterator<Item = &String>,
) -> HashMap<RumaUserId, HashMap<DeviceId, KeyAlgorithm>> {
users: impl Iterator<Item = &UserId>,
) -> HashMap<UserId, HashMap<DeviceId, KeyAlgorithm>> {
let mut olm = self.olm.lock().await;
match &mut *olm {
@ -431,7 +431,7 @@ impl Client {
/// Returns an empty error if no keys need to be queried.
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn users_for_key_query(&self) -> StdResult<HashSet<String>, ()> {
pub async fn users_for_key_query(&self) -> StdResult<HashSet<UserId>, ()> {
let olm = self.olm.lock().await;
match &*olm {

View file

@ -71,7 +71,7 @@ pub struct OlmMachine {
store: Box<dyn CryptoStore>,
/// Set of users that we need to query keys for. This is a subset of
/// the tracked users in the CryptoStore.
users_for_key_query: HashSet<String>,
users_for_key_query: HashSet<UserId>,
}
impl OlmMachine {
@ -101,8 +101,7 @@ impl OlmMachine {
passphrase: String,
) -> Result<Self> {
let mut store =
SqliteStore::open_with_passphrase(&user_id.to_string(), device_id, path, passphrase)
.await?;
SqliteStore::open_with_passphrase(&user_id, device_id, path, passphrase).await?;
let account = match store.load_account().await? {
Some(a) => {
@ -183,12 +182,12 @@ impl OlmMachine {
pub async fn get_missing_sessions(
&mut self,
users: impl Iterator<Item = &String>,
users: impl Iterator<Item = &UserId>,
) -> HashMap<UserId, HashMap<DeviceId, KeyAlgorithm>> {
let mut missing = HashMap::new();
for user_id in users {
let user_devices = self.store.get_user_devices(&user_id).await.unwrap();
let user_devices = self.store.get_user_devices(user_id).await.unwrap();
for device in user_devices.devices() {
let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) {
@ -206,12 +205,11 @@ impl OlmMachine {
};
if is_missing {
let user_id = UserId::try_from(user_id.as_ref()).unwrap();
if !missing.contains_key(&user_id) {
missing.insert(user_id.to_owned(), HashMap::new());
if !missing.contains_key(user_id) {
missing.insert(user_id.clone(), HashMap::new());
}
let user_map = missing.get_mut(&user_id).unwrap();
let user_map = missing.get_mut(user_id).unwrap();
user_map.insert(
device.device_id().to_owned(),
KeyAlgorithm::SignedCurve25519,
@ -233,7 +231,7 @@ impl OlmMachine {
for (device_id, key_map) in user_devices {
let device = if let Some(d) = self
.store
.get_device(&user_id.to_string(), device_id)
.get_device(&user_id, device_id)
.await
.expect("Can't get devices")
{
@ -346,8 +344,7 @@ impl OlmMachine {
let mut changed_devices = Vec::new();
for (user_id, device_map) in &response.device_keys {
let user_id_string = user_id.to_string();
self.users_for_key_query.remove(&user_id_string);
self.users_for_key_query.remove(&user_id);
for (device_id, device_keys) in device_map.iter() {
// We don't need our own device in the device store.
@ -393,7 +390,7 @@ impl OlmMachine {
let device = self
.store
.get_device(&user_id_string, device_id)
.get_device(&user_id, device_id)
.await
.expect("Can't load device");
@ -407,7 +404,7 @@ impl OlmMachine {
}
let current_devices: HashSet<&String> = device_map.keys().collect();
let stored_devices = self.store.get_user_devices(&user_id_string).await.unwrap();
let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
let stored_devices_set: HashSet<&String> = stored_devices.keys().collect();
let deleted_devices = stored_devices_set.difference(&current_devices);
@ -767,7 +764,7 @@ impl OlmMachine {
let session = InboundGroupSession::new(
sender_key,
signing_key,
&event.content.room_id.to_string(),
&event.content.room_id,
&event.content.session_key,
)?;
self.store.save_inbound_group_session(session).await?;
@ -893,11 +890,7 @@ impl OlmMachine {
let session = self
.store
.get_inbound_group_session(
&room_id.to_string(),
&content.sender_key,
&content.session_id,
)
.get_inbound_group_session(&room_id, &content.sender_key, &content.session_id)
.await?;
// TODO check if the olm session is wedged and re-request the key.
let session = session.ok_or(OlmError::MissingSession)?;
@ -936,7 +929,7 @@ impl OlmMachine {
/// Use the `mark_user_as_changed()` if the user really needs a key query.
pub async fn update_tracked_users<'a, I>(&mut self, users: I)
where
I: IntoIterator<Item = &'a String>,
I: IntoIterator<Item = &'a UserId>,
{
for user in users {
let ret = self.store.add_user_for_tracking(user).await;
@ -944,12 +937,12 @@ impl OlmMachine {
match ret {
Ok(newly_added) => {
if newly_added {
self.users_for_key_query.insert(user.to_string());
self.users_for_key_query.insert(user.clone());
}
}
Err(e) => {
warn!("Error storing users for tracking {}", e);
self.users_for_key_query.insert(user.to_string());
self.users_for_key_query.insert(user.clone());
}
}
}
@ -961,7 +954,7 @@ impl OlmMachine {
}
/// Get the set of users that we need to query keys for.
pub fn users_for_key_query(&self) -> HashSet<String> {
pub fn users_for_key_query(&self) -> HashSet<UserId> {
self.users_for_key_query.clone()
}
}

View file

@ -13,6 +13,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::Arc;
use dashmap::{DashMap, ReadOnlyView};
@ -20,6 +21,7 @@ use tokio::sync::Mutex;
use super::device::Device;
use super::olm::{InboundGroupSession, Session};
use crate::identifiers::{RoomId, UserId};
#[derive(Debug)]
pub struct SessionStore {
@ -59,7 +61,7 @@ impl SessionStore {
#[derive(Debug)]
pub struct GroupSessionStore {
entries: HashMap<String, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>,
entries: HashMap<RoomId, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>,
}
impl GroupSessionStore {
@ -89,7 +91,7 @@ impl GroupSessionStore {
pub fn get(
&self,
room_id: &str,
room_id: &RoomId,
sender_key: &str,
session_id: &str,
) -> Option<Arc<Mutex<InboundGroupSession>>> {
@ -101,7 +103,7 @@ impl GroupSessionStore {
#[derive(Clone, Debug)]
pub struct DeviceStore {
entries: Arc<DashMap<String, DashMap<String, Device>>>,
entries: Arc<DashMap<UserId, DashMap<String, Device>>>,
}
pub struct UserDevices {
@ -130,26 +132,27 @@ impl DeviceStore {
}
pub fn add(&self, device: Device) -> bool {
if !self.entries.contains_key(device.user_id()) {
self.entries
.insert(device.user_id().to_owned(), DashMap::new());
let user_id = UserId::try_from(device.user_id()).unwrap();
if !self.entries.contains_key(&user_id) {
self.entries.insert(user_id.clone(), DashMap::new());
}
let device_map = self.entries.get_mut(device.user_id()).unwrap();
let device_map = self.entries.get_mut(&user_id).unwrap();
device_map
.insert(device.device_id().to_owned(), device)
.is_some()
}
pub fn get(&self, user_id: &str, device_id: &str) -> Option<Device> {
pub fn get(&self, user_id: &UserId, device_id: &str) -> Option<Device> {
self.entries
.get(user_id)
.and_then(|m| m.get(device_id).map(|d| d.value().clone()))
}
pub fn user_devices(&self, user_id: &str) -> UserDevices {
pub fn user_devices(&self, user_id: &UserId) -> UserDevices {
if !self.entries.contains_key(user_id) {
self.entries.insert(user_id.to_owned(), DashMap::new());
self.entries.insert(user_id.clone(), DashMap::new());
}
UserDevices {
entries: self.entries.get(user_id).unwrap().clone().into_read_only(),

View file

@ -23,6 +23,8 @@ use olm_rs::PicklingMode;
use ruma_client_api::r0::keys::SignedKey;
use crate::identifiers::{RoomId, UserId};
pub struct Account {
inner: OlmAccount,
pub(crate) shared: bool,
@ -210,7 +212,7 @@ pub struct InboundGroupSession {
inner: OlmInboundGroupSession,
pub(crate) sender_key: String,
pub(crate) signing_key: String,
pub(crate) room_id: String,
pub(crate) room_id: RoomId,
forwarding_chains: Option<Vec<String>>,
}
@ -218,14 +220,14 @@ impl InboundGroupSession {
pub fn new(
sender_key: &str,
signing_key: &str,
room_id: &str,
room_id: &RoomId,
session_key: &str,
) -> Result<Self, OlmGroupSessionError> {
Ok(InboundGroupSession {
inner: OlmInboundGroupSession::new(session_key)?,
sender_key: sender_key.to_owned(),
signing_key: signing_key.to_owned(),
room_id: room_id.to_owned(),
room_id: room_id.clone(),
forwarding_chains: None,
})
}
@ -235,7 +237,7 @@ impl InboundGroupSession {
pickle_mode: PicklingMode,
sender_key: String,
signing_key: String,
room_id: String,
room_id: RoomId,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
Ok(InboundGroupSession {

View file

@ -21,12 +21,13 @@ use tokio::sync::Mutex;
use super::{Account, CryptoStore, InboundGroupSession, Result, Session};
use crate::crypto::device::Device;
use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
use crate::identifiers::{RoomId, UserId};
#[derive(Debug)]
pub struct MemoryStore {
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
tracked_users: HashSet<String>,
tracked_users: HashSet<UserId>,
devices: DeviceStore,
}
@ -73,7 +74,7 @@ impl CryptoStore for MemoryStore {
async fn get_inbound_group_session(
&mut self,
room_id: &str,
room_id: &RoomId,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
@ -82,19 +83,19 @@ impl CryptoStore for MemoryStore {
.get(room_id, sender_key, session_id))
}
fn tracked_users(&self) -> &HashSet<String> {
fn tracked_users(&self) -> &HashSet<UserId> {
&self.tracked_users
}
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> {
Ok(self.tracked_users.insert(user.to_string()))
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
Ok(self.tracked_users.insert(user.clone()))
}
async fn get_device(&self, user_id: &str, device_id: &str) -> Result<Option<Device>> {
async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result<Option<Device>> {
Ok(self.devices.get(user_id, device_id))
}
async fn get_user_devices(&self, user_id: &str) -> Result<UserDevices> {
async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
Ok(self.devices.user_devices(user_id))
}

View file

@ -26,6 +26,7 @@ use tokio::sync::Mutex;
use super::device::Device;
use super::memory_stores::UserDevices;
use super::olm::{Account, InboundGroupSession, Session};
use crate::identifiers::{RoomId, UserId};
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
pub mod memorystore;
@ -75,13 +76,13 @@ pub trait CryptoStore: Debug + Send + Sync {
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
async fn get_inbound_group_session(
&mut self,
room_id: &str,
room_id: &RoomId,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
fn tracked_users(&self) -> &HashSet<String>;
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool>;
fn tracked_users(&self) -> &HashSet<UserId>;
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
async fn save_device(&self, device: Device) -> Result<()>;
async fn get_device(&self, user_id: &str, device_id: &str) -> Result<Option<Device>>;
async fn get_user_devices(&self, user_id: &str) -> Result<UserDevices>;
async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result<Option<Device>>;
async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices>;
}

View file

@ -13,6 +13,7 @@
// limitations under the License.
use std::collections::HashSet;
use std::convert::TryFrom;
use std::path::{Path, PathBuf};
use std::result::Result as StdResult;
use std::sync::Arc;
@ -29,6 +30,7 @@ use zeroize::Zeroizing;
use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session};
use crate::crypto::device::Device;
use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices};
use crate::identifiers::{RoomId, UserId};
pub struct SqliteStore {
user_id: Arc<String>,
@ -39,14 +41,14 @@ pub struct SqliteStore {
inbound_group_sessions: GroupSessionStore,
connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>,
tracked_users: HashSet<String>,
tracked_users: HashSet<UserId>,
}
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
impl SqliteStore {
pub async fn open<P: AsRef<Path>>(
user_id: &str,
user_id: &UserId,
device_id: &str,
path: P,
) -> Result<SqliteStore> {
@ -54,7 +56,7 @@ impl SqliteStore {
}
pub async fn open_with_passphrase<P: AsRef<Path>>(
user_id: &str,
user_id: &UserId,
device_id: &str,
path: P,
passphrase: String,
@ -69,7 +71,7 @@ impl SqliteStore {
}
async fn open_helper<P: AsRef<Path>>(
user_id: &str,
user_id: &UserId,
device_id: &str,
path: P,
passphrase: Option<Zeroizing<String>>,
@ -78,7 +80,7 @@ impl SqliteStore {
let connection = SqliteConnection::connect(url.as_ref()).await?;
let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()),
user_id: Arc::new(user_id.to_string()),
device_id: Arc::new(device_id.to_owned()),
account_id: None,
sessions: SessionStore::new(),
@ -230,7 +232,7 @@ impl SqliteStore {
self.get_pickle_mode(),
sender_key.to_string(),
signing_key.to_owned(),
room_id.to_owned(),
RoomId::try_from(room_id.as_str()).unwrap(),
)?)
})
.collect::<Result<Vec<InboundGroupSession>>>()?)
@ -302,8 +304,8 @@ impl CryptoStore for SqliteStore {
device_id = ?2
",
)
.bind(&*self.user_id)
.bind(&*self.device_id)
.bind(&*self.user_id.to_string())
.bind(&*self.device_id.to_string())
.bind(&pickle)
.bind(acc.shared)
.execute(&mut *connection)
@ -311,8 +313,8 @@ impl CryptoStore for SqliteStore {
let account_id: (i64,) =
query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?")
.bind(&*self.user_id)
.bind(&*self.device_id)
.bind(&*self.user_id.to_string())
.bind(&*self.device_id.to_string())
.fetch_one(&mut *connection)
.await?;
@ -383,7 +385,7 @@ impl CryptoStore for SqliteStore {
.bind(account_id)
.bind(&session.sender_key)
.bind(&session.signing_key)
.bind(&session.room_id)
.bind(&session.room_id.to_string())
.bind(&pickle)
.execute(&mut *connection)
.await?;
@ -393,7 +395,7 @@ impl CryptoStore for SqliteStore {
async fn get_inbound_group_session(
&mut self,
room_id: &str,
room_id: &RoomId,
sender_key: &str,
session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
@ -402,19 +404,19 @@ impl CryptoStore for SqliteStore {
.get(room_id, sender_key, session_id))
}
fn tracked_users(&self) -> &HashSet<String> {
fn tracked_users(&self) -> &HashSet<UserId> {
&self.tracked_users
}
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> {
Ok(self.tracked_users.insert(user.to_string()))
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
Ok(self.tracked_users.insert(user.clone()))
}
async fn get_device(&self, _user_id: &str, _device_id: &str) -> Result<Option<Device>> {
async fn get_device(&self, _user_id: &UserId, _device_id: &str) -> Result<Option<Device>> {
todo!()
}
async fn get_user_devices(&self, _user_id: &str) -> Result<UserDevices> {
async fn get_user_devices(&self, _user_id: &UserId) -> Result<UserDevices> {
todo!()
}
@ -442,7 +444,9 @@ mod test {
use tempfile::tempdir;
use tokio::sync::Mutex;
use super::{Account, CryptoStore, InboundGroupSession, Session, SqliteStore};
use super::{
Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom, UserId,
};
static USER_ID: &str = "@example:localhost";
static DEVICE_ID: &str = "DEVICEID";
@ -450,7 +454,7 @@ mod test {
async fn get_store() -> SqliteStore {
let tmpdir = tempdir().unwrap();
let tmpdir_path = tmpdir.path().to_str().unwrap();
SqliteStore::open(USER_ID, DEVICE_ID, tmpdir_path)
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, tmpdir_path)
.await
.expect("Can't create store")
}
@ -501,7 +505,7 @@ mod test {
async fn create_store() {
let tmpdir = tempdir().unwrap();
let tmpdir_path = tmpdir.path().to_str().unwrap();
let _ = SqliteStore::open("@example:localhost", "DEVICEID", tmpdir_path)
let _ = SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), "DEVICEID", tmpdir_path)
.await
.expect("Can't create store");
}
@ -626,7 +630,7 @@ mod test {
let session = InboundGroupSession::new(
identity_keys.curve25519(),
identity_keys.ed25519(),
"!test:localhost",
&RoomId::try_from("!test:localhost").unwrap(),
&outbound_session.session_key(),
)
.expect("Can't create session");
@ -647,7 +651,7 @@ mod test {
let session = InboundGroupSession::new(
identity_keys.curve25519(),
identity_keys.ed25519(),
"!test:localhost",
&RoomId::try_from("!test:localhost").unwrap(),
&outbound_session.session_key(),
)
.expect("Can't create session");

View file

@ -65,7 +65,7 @@ use tokio::sync::Mutex;
/// } = event.lock().await.deref()
/// {
/// let rooms = room.lock().await;
/// let member = rooms.members.get(&sender.to_string()).unwrap();
/// let member = rooms.members.get(&sender).unwrap();
/// println!(
/// "{}: {}",
/// member