crypto compiles, tests pass

master
Devin R 2020-04-03 11:00:37 -04:00
parent db62c1fa25
commit eff322c0c5
9 changed files with 83 additions and 79 deletions

View File

@ -623,7 +623,7 @@ impl AsyncClient {
/// * `data` - The content of the message. /// * `data` - The content of the message.
pub async fn room_send( pub async fn room_send(
&mut self, &mut self,
room_id: &str, room_id: &RoomId,
data: MessageEventContent, data: MessageEventContent,
) -> Result<create_message_event::Response> { ) -> Result<create_message_event::Response> {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
@ -658,7 +658,7 @@ impl AsyncClient {
} }
let request = create_message_event::Request { let request = create_message_event::Request {
room_id: RoomId::try_from(room_id).unwrap(), room_id: room_id.clone(),
event_type: EventType::RoomMessage, event_type: EventType::RoomMessage,
txn_id: self.transaction_id().to_string(), txn_id: self.transaction_id().to_string(),
data, data,
@ -771,7 +771,7 @@ impl AsyncClient {
let mut device_keys: HashMap<UserId, Vec<DeviceId>> = HashMap::new(); let mut device_keys: HashMap<UserId, Vec<DeviceId>> = HashMap::new();
for user in users_for_query.drain() { for user in users_for_query.drain() {
device_keys.insert(UserId::try_from(user.as_ref()).unwrap(), Vec::new()); device_keys.insert(user, Vec::new());
} }
let request = get_keys::Request { let request = get_keys::Request {

View File

@ -400,8 +400,8 @@ impl Client {
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: impl Iterator<Item = &String>, users: impl Iterator<Item = &UserId>,
) -> HashMap<RumaUserId, HashMap<DeviceId, KeyAlgorithm>> { ) -> HashMap<UserId, HashMap<DeviceId, KeyAlgorithm>> {
let mut olm = self.olm.lock().await; let mut olm = self.olm.lock().await;
match &mut *olm { match &mut *olm {
@ -431,7 +431,7 @@ impl Client {
/// Returns an empty error if no keys need to be queried. /// Returns an empty error if no keys need to be queried.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn users_for_key_query(&self) -> StdResult<HashSet<String>, ()> { pub async fn users_for_key_query(&self) -> StdResult<HashSet<UserId>, ()> {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
match &*olm { match &*olm {

View File

@ -71,7 +71,7 @@ pub struct OlmMachine {
store: Box<dyn CryptoStore>, store: Box<dyn CryptoStore>,
/// Set of users that we need to query keys for. This is a subset of /// Set of users that we need to query keys for. This is a subset of
/// the tracked users in the CryptoStore. /// the tracked users in the CryptoStore.
users_for_key_query: HashSet<String>, users_for_key_query: HashSet<UserId>,
} }
impl OlmMachine { impl OlmMachine {
@ -101,8 +101,7 @@ impl OlmMachine {
passphrase: String, passphrase: String,
) -> Result<Self> { ) -> Result<Self> {
let mut store = let mut store =
SqliteStore::open_with_passphrase(&user_id.to_string(), device_id, path, passphrase) SqliteStore::open_with_passphrase(&user_id, device_id, path, passphrase).await?;
.await?;
let account = match store.load_account().await? { let account = match store.load_account().await? {
Some(a) => { Some(a) => {
@ -183,12 +182,12 @@ impl OlmMachine {
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&mut self, &mut self,
users: impl Iterator<Item = &String>, users: impl Iterator<Item = &UserId>,
) -> HashMap<UserId, HashMap<DeviceId, KeyAlgorithm>> { ) -> HashMap<UserId, HashMap<DeviceId, KeyAlgorithm>> {
let mut missing = HashMap::new(); let mut missing = HashMap::new();
for user_id in users { for user_id in users {
let user_devices = self.store.get_user_devices(&user_id).await.unwrap(); let user_devices = self.store.get_user_devices(user_id).await.unwrap();
for device in user_devices.devices() { for device in user_devices.devices() {
let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) { let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) {
@ -206,12 +205,11 @@ impl OlmMachine {
}; };
if is_missing { if is_missing {
let user_id = UserId::try_from(user_id.as_ref()).unwrap(); if !missing.contains_key(user_id) {
if !missing.contains_key(&user_id) { missing.insert(user_id.clone(), HashMap::new());
missing.insert(user_id.to_owned(), HashMap::new());
} }
let user_map = missing.get_mut(&user_id).unwrap(); let user_map = missing.get_mut(user_id).unwrap();
user_map.insert( user_map.insert(
device.device_id().to_owned(), device.device_id().to_owned(),
KeyAlgorithm::SignedCurve25519, KeyAlgorithm::SignedCurve25519,
@ -233,7 +231,7 @@ impl OlmMachine {
for (device_id, key_map) in user_devices { for (device_id, key_map) in user_devices {
let device = if let Some(d) = self let device = if let Some(d) = self
.store .store
.get_device(&user_id.to_string(), device_id) .get_device(&user_id, device_id)
.await .await
.expect("Can't get devices") .expect("Can't get devices")
{ {
@ -346,8 +344,7 @@ impl OlmMachine {
let mut changed_devices = Vec::new(); let mut changed_devices = Vec::new();
for (user_id, device_map) in &response.device_keys { for (user_id, device_map) in &response.device_keys {
let user_id_string = user_id.to_string(); self.users_for_key_query.remove(&user_id);
self.users_for_key_query.remove(&user_id_string);
for (device_id, device_keys) in device_map.iter() { for (device_id, device_keys) in device_map.iter() {
// We don't need our own device in the device store. // We don't need our own device in the device store.
@ -393,7 +390,7 @@ impl OlmMachine {
let device = self let device = self
.store .store
.get_device(&user_id_string, device_id) .get_device(&user_id, device_id)
.await .await
.expect("Can't load device"); .expect("Can't load device");
@ -407,7 +404,7 @@ impl OlmMachine {
} }
let current_devices: HashSet<&String> = device_map.keys().collect(); let current_devices: HashSet<&String> = device_map.keys().collect();
let stored_devices = self.store.get_user_devices(&user_id_string).await.unwrap(); let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
let stored_devices_set: HashSet<&String> = stored_devices.keys().collect(); let stored_devices_set: HashSet<&String> = stored_devices.keys().collect();
let deleted_devices = stored_devices_set.difference(&current_devices); let deleted_devices = stored_devices_set.difference(&current_devices);
@ -767,7 +764,7 @@ impl OlmMachine {
let session = InboundGroupSession::new( let session = InboundGroupSession::new(
sender_key, sender_key,
signing_key, signing_key,
&event.content.room_id.to_string(), &event.content.room_id,
&event.content.session_key, &event.content.session_key,
)?; )?;
self.store.save_inbound_group_session(session).await?; self.store.save_inbound_group_session(session).await?;
@ -893,11 +890,7 @@ impl OlmMachine {
let session = self let session = self
.store .store
.get_inbound_group_session( .get_inbound_group_session(&room_id, &content.sender_key, &content.session_id)
&room_id.to_string(),
&content.sender_key,
&content.session_id,
)
.await?; .await?;
// TODO check if the olm session is wedged and re-request the key. // TODO check if the olm session is wedged and re-request the key.
let session = session.ok_or(OlmError::MissingSession)?; let session = session.ok_or(OlmError::MissingSession)?;
@ -936,7 +929,7 @@ impl OlmMachine {
/// Use the `mark_user_as_changed()` if the user really needs a key query. /// Use the `mark_user_as_changed()` if the user really needs a key query.
pub async fn update_tracked_users<'a, I>(&mut self, users: I) pub async fn update_tracked_users<'a, I>(&mut self, users: I)
where where
I: IntoIterator<Item = &'a String>, I: IntoIterator<Item = &'a UserId>,
{ {
for user in users { for user in users {
let ret = self.store.add_user_for_tracking(user).await; let ret = self.store.add_user_for_tracking(user).await;
@ -944,12 +937,12 @@ impl OlmMachine {
match ret { match ret {
Ok(newly_added) => { Ok(newly_added) => {
if newly_added { if newly_added {
self.users_for_key_query.insert(user.to_string()); self.users_for_key_query.insert(user.clone());
} }
} }
Err(e) => { Err(e) => {
warn!("Error storing users for tracking {}", e); warn!("Error storing users for tracking {}", e);
self.users_for_key_query.insert(user.to_string()); self.users_for_key_query.insert(user.clone());
} }
} }
} }
@ -961,7 +954,7 @@ impl OlmMachine {
} }
/// Get the set of users that we need to query keys for. /// Get the set of users that we need to query keys for.
pub fn users_for_key_query(&self) -> HashSet<String> { pub fn users_for_key_query(&self) -> HashSet<UserId> {
self.users_for_key_query.clone() self.users_for_key_query.clone()
} }
} }

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::Arc; use std::sync::Arc;
use dashmap::{DashMap, ReadOnlyView}; use dashmap::{DashMap, ReadOnlyView};
@ -20,6 +21,7 @@ use tokio::sync::Mutex;
use super::device::Device; use super::device::Device;
use super::olm::{InboundGroupSession, Session}; use super::olm::{InboundGroupSession, Session};
use crate::identifiers::{RoomId, UserId};
#[derive(Debug)] #[derive(Debug)]
pub struct SessionStore { pub struct SessionStore {
@ -59,7 +61,7 @@ impl SessionStore {
#[derive(Debug)] #[derive(Debug)]
pub struct GroupSessionStore { pub struct GroupSessionStore {
entries: HashMap<String, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>, entries: HashMap<RoomId, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>,
} }
impl GroupSessionStore { impl GroupSessionStore {
@ -89,7 +91,7 @@ impl GroupSessionStore {
pub fn get( pub fn get(
&self, &self,
room_id: &str, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Option<Arc<Mutex<InboundGroupSession>>> { ) -> Option<Arc<Mutex<InboundGroupSession>>> {
@ -101,7 +103,7 @@ impl GroupSessionStore {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct DeviceStore { pub struct DeviceStore {
entries: Arc<DashMap<String, DashMap<String, Device>>>, entries: Arc<DashMap<UserId, DashMap<String, Device>>>,
} }
pub struct UserDevices { pub struct UserDevices {
@ -130,26 +132,27 @@ impl DeviceStore {
} }
pub fn add(&self, device: Device) -> bool { pub fn add(&self, device: Device) -> bool {
if !self.entries.contains_key(device.user_id()) { let user_id = UserId::try_from(device.user_id()).unwrap();
self.entries
.insert(device.user_id().to_owned(), DashMap::new()); if !self.entries.contains_key(&user_id) {
self.entries.insert(user_id.clone(), DashMap::new());
} }
let device_map = self.entries.get_mut(device.user_id()).unwrap(); let device_map = self.entries.get_mut(&user_id).unwrap();
device_map device_map
.insert(device.device_id().to_owned(), device) .insert(device.device_id().to_owned(), device)
.is_some() .is_some()
} }
pub fn get(&self, user_id: &str, device_id: &str) -> Option<Device> { pub fn get(&self, user_id: &UserId, device_id: &str) -> Option<Device> {
self.entries self.entries
.get(user_id) .get(user_id)
.and_then(|m| m.get(device_id).map(|d| d.value().clone())) .and_then(|m| m.get(device_id).map(|d| d.value().clone()))
} }
pub fn user_devices(&self, user_id: &str) -> UserDevices { pub fn user_devices(&self, user_id: &UserId) -> UserDevices {
if !self.entries.contains_key(user_id) { if !self.entries.contains_key(user_id) {
self.entries.insert(user_id.to_owned(), DashMap::new()); self.entries.insert(user_id.clone(), DashMap::new());
} }
UserDevices { UserDevices {
entries: self.entries.get(user_id).unwrap().clone().into_read_only(), entries: self.entries.get(user_id).unwrap().clone().into_read_only(),

View File

@ -23,6 +23,8 @@ use olm_rs::PicklingMode;
use ruma_client_api::r0::keys::SignedKey; use ruma_client_api::r0::keys::SignedKey;
use crate::identifiers::{RoomId, UserId};
pub struct Account { pub struct Account {
inner: OlmAccount, inner: OlmAccount,
pub(crate) shared: bool, pub(crate) shared: bool,
@ -210,7 +212,7 @@ pub struct InboundGroupSession {
inner: OlmInboundGroupSession, inner: OlmInboundGroupSession,
pub(crate) sender_key: String, pub(crate) sender_key: String,
pub(crate) signing_key: String, pub(crate) signing_key: String,
pub(crate) room_id: String, pub(crate) room_id: RoomId,
forwarding_chains: Option<Vec<String>>, forwarding_chains: Option<Vec<String>>,
} }
@ -218,14 +220,14 @@ impl InboundGroupSession {
pub fn new( pub fn new(
sender_key: &str, sender_key: &str,
signing_key: &str, signing_key: &str,
room_id: &str, room_id: &RoomId,
session_key: &str, session_key: &str,
) -> Result<Self, OlmGroupSessionError> { ) -> Result<Self, OlmGroupSessionError> {
Ok(InboundGroupSession { Ok(InboundGroupSession {
inner: OlmInboundGroupSession::new(session_key)?, inner: OlmInboundGroupSession::new(session_key)?,
sender_key: sender_key.to_owned(), sender_key: sender_key.to_owned(),
signing_key: signing_key.to_owned(), signing_key: signing_key.to_owned(),
room_id: room_id.to_owned(), room_id: room_id.clone(),
forwarding_chains: None, forwarding_chains: None,
}) })
} }
@ -235,7 +237,7 @@ impl InboundGroupSession {
pickle_mode: PicklingMode, pickle_mode: PicklingMode,
sender_key: String, sender_key: String,
signing_key: String, signing_key: String,
room_id: String, room_id: RoomId,
) -> Result<Self, OlmGroupSessionError> { ) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?; let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
Ok(InboundGroupSession { Ok(InboundGroupSession {

View File

@ -21,12 +21,13 @@ use tokio::sync::Mutex;
use super::{Account, CryptoStore, InboundGroupSession, Result, Session}; use super::{Account, CryptoStore, InboundGroupSession, Result, Session};
use crate::crypto::device::Device; use crate::crypto::device::Device;
use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
use crate::identifiers::{RoomId, UserId};
#[derive(Debug)] #[derive(Debug)]
pub struct MemoryStore { pub struct MemoryStore {
sessions: SessionStore, sessions: SessionStore,
inbound_group_sessions: GroupSessionStore, inbound_group_sessions: GroupSessionStore,
tracked_users: HashSet<String>, tracked_users: HashSet<UserId>,
devices: DeviceStore, devices: DeviceStore,
} }
@ -73,7 +74,7 @@ impl CryptoStore for MemoryStore {
async fn get_inbound_group_session( async fn get_inbound_group_session(
&mut self, &mut self,
room_id: &str, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> { ) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
@ -82,19 +83,19 @@ impl CryptoStore for MemoryStore {
.get(room_id, sender_key, session_id)) .get(room_id, sender_key, session_id))
} }
fn tracked_users(&self) -> &HashSet<String> { fn tracked_users(&self) -> &HashSet<UserId> {
&self.tracked_users &self.tracked_users
} }
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> { async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
Ok(self.tracked_users.insert(user.to_string())) Ok(self.tracked_users.insert(user.clone()))
} }
async fn get_device(&self, user_id: &str, device_id: &str) -> Result<Option<Device>> { async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result<Option<Device>> {
Ok(self.devices.get(user_id, device_id)) Ok(self.devices.get(user_id, device_id))
} }
async fn get_user_devices(&self, user_id: &str) -> Result<UserDevices> { async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
Ok(self.devices.user_devices(user_id)) Ok(self.devices.user_devices(user_id))
} }

View File

@ -26,6 +26,7 @@ use tokio::sync::Mutex;
use super::device::Device; use super::device::Device;
use super::memory_stores::UserDevices; use super::memory_stores::UserDevices;
use super::olm::{Account, InboundGroupSession, Session}; use super::olm::{Account, InboundGroupSession, Session};
use crate::identifiers::{RoomId, UserId};
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
pub mod memorystore; pub mod memorystore;
@ -75,13 +76,13 @@ pub trait CryptoStore: Debug + Send + Sync {
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>; async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
async fn get_inbound_group_session( async fn get_inbound_group_session(
&mut self, &mut self,
room_id: &str, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>; ) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
fn tracked_users(&self) -> &HashSet<String>; fn tracked_users(&self) -> &HashSet<UserId>;
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool>; async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
async fn save_device(&self, device: Device) -> Result<()>; async fn save_device(&self, device: Device) -> Result<()>;
async fn get_device(&self, user_id: &str, device_id: &str) -> Result<Option<Device>>; async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result<Option<Device>>;
async fn get_user_devices(&self, user_id: &str) -> Result<UserDevices>; async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices>;
} }

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use std::collections::HashSet; use std::collections::HashSet;
use std::convert::TryFrom;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::result::Result as StdResult; use std::result::Result as StdResult;
use std::sync::Arc; use std::sync::Arc;
@ -29,6 +30,7 @@ use zeroize::Zeroizing;
use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session};
use crate::crypto::device::Device; use crate::crypto::device::Device;
use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices}; use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices};
use crate::identifiers::{RoomId, UserId};
pub struct SqliteStore { pub struct SqliteStore {
user_id: Arc<String>, user_id: Arc<String>,
@ -39,14 +41,14 @@ pub struct SqliteStore {
inbound_group_sessions: GroupSessionStore, inbound_group_sessions: GroupSessionStore,
connection: Arc<Mutex<SqliteConnection>>, connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>, pickle_passphrase: Option<Zeroizing<String>>,
tracked_users: HashSet<String>, tracked_users: HashSet<UserId>,
} }
static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
impl SqliteStore { impl SqliteStore {
pub async fn open<P: AsRef<Path>>( pub async fn open<P: AsRef<Path>>(
user_id: &str, user_id: &UserId,
device_id: &str, device_id: &str,
path: P, path: P,
) -> Result<SqliteStore> { ) -> Result<SqliteStore> {
@ -54,7 +56,7 @@ impl SqliteStore {
} }
pub async fn open_with_passphrase<P: AsRef<Path>>( pub async fn open_with_passphrase<P: AsRef<Path>>(
user_id: &str, user_id: &UserId,
device_id: &str, device_id: &str,
path: P, path: P,
passphrase: String, passphrase: String,
@ -69,7 +71,7 @@ impl SqliteStore {
} }
async fn open_helper<P: AsRef<Path>>( async fn open_helper<P: AsRef<Path>>(
user_id: &str, user_id: &UserId,
device_id: &str, device_id: &str,
path: P, path: P,
passphrase: Option<Zeroizing<String>>, passphrase: Option<Zeroizing<String>>,
@ -78,7 +80,7 @@ impl SqliteStore {
let connection = SqliteConnection::connect(url.as_ref()).await?; let connection = SqliteConnection::connect(url.as_ref()).await?;
let store = SqliteStore { let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_string()),
device_id: Arc::new(device_id.to_owned()), device_id: Arc::new(device_id.to_owned()),
account_id: None, account_id: None,
sessions: SessionStore::new(), sessions: SessionStore::new(),
@ -230,7 +232,7 @@ impl SqliteStore {
self.get_pickle_mode(), self.get_pickle_mode(),
sender_key.to_string(), sender_key.to_string(),
signing_key.to_owned(), signing_key.to_owned(),
room_id.to_owned(), RoomId::try_from(room_id.as_str()).unwrap(),
)?) )?)
}) })
.collect::<Result<Vec<InboundGroupSession>>>()?) .collect::<Result<Vec<InboundGroupSession>>>()?)
@ -302,8 +304,8 @@ impl CryptoStore for SqliteStore {
device_id = ?2 device_id = ?2
", ",
) )
.bind(&*self.user_id) .bind(&*self.user_id.to_string())
.bind(&*self.device_id) .bind(&*self.device_id.to_string())
.bind(&pickle) .bind(&pickle)
.bind(acc.shared) .bind(acc.shared)
.execute(&mut *connection) .execute(&mut *connection)
@ -311,8 +313,8 @@ impl CryptoStore for SqliteStore {
let account_id: (i64,) = let account_id: (i64,) =
query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?") query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?")
.bind(&*self.user_id) .bind(&*self.user_id.to_string())
.bind(&*self.device_id) .bind(&*self.device_id.to_string())
.fetch_one(&mut *connection) .fetch_one(&mut *connection)
.await?; .await?;
@ -383,7 +385,7 @@ impl CryptoStore for SqliteStore {
.bind(account_id) .bind(account_id)
.bind(&session.sender_key) .bind(&session.sender_key)
.bind(&session.signing_key) .bind(&session.signing_key)
.bind(&session.room_id) .bind(&session.room_id.to_string())
.bind(&pickle) .bind(&pickle)
.execute(&mut *connection) .execute(&mut *connection)
.await?; .await?;
@ -393,7 +395,7 @@ impl CryptoStore for SqliteStore {
async fn get_inbound_group_session( async fn get_inbound_group_session(
&mut self, &mut self,
room_id: &str, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> { ) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
@ -402,19 +404,19 @@ impl CryptoStore for SqliteStore {
.get(room_id, sender_key, session_id)) .get(room_id, sender_key, session_id))
} }
fn tracked_users(&self) -> &HashSet<String> { fn tracked_users(&self) -> &HashSet<UserId> {
&self.tracked_users &self.tracked_users
} }
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> { async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
Ok(self.tracked_users.insert(user.to_string())) Ok(self.tracked_users.insert(user.clone()))
} }
async fn get_device(&self, _user_id: &str, _device_id: &str) -> Result<Option<Device>> { async fn get_device(&self, _user_id: &UserId, _device_id: &str) -> Result<Option<Device>> {
todo!() todo!()
} }
async fn get_user_devices(&self, _user_id: &str) -> Result<UserDevices> { async fn get_user_devices(&self, _user_id: &UserId) -> Result<UserDevices> {
todo!() todo!()
} }
@ -442,7 +444,9 @@ mod test {
use tempfile::tempdir; use tempfile::tempdir;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use super::{Account, CryptoStore, InboundGroupSession, Session, SqliteStore}; use super::{
Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom, UserId,
};
static USER_ID: &str = "@example:localhost"; static USER_ID: &str = "@example:localhost";
static DEVICE_ID: &str = "DEVICEID"; static DEVICE_ID: &str = "DEVICEID";
@ -450,7 +454,7 @@ mod test {
async fn get_store() -> SqliteStore { async fn get_store() -> SqliteStore {
let tmpdir = tempdir().unwrap(); let tmpdir = tempdir().unwrap();
let tmpdir_path = tmpdir.path().to_str().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap();
SqliteStore::open(USER_ID, DEVICE_ID, tmpdir_path) SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, tmpdir_path)
.await .await
.expect("Can't create store") .expect("Can't create store")
} }
@ -501,7 +505,7 @@ mod test {
async fn create_store() { async fn create_store() {
let tmpdir = tempdir().unwrap(); let tmpdir = tempdir().unwrap();
let tmpdir_path = tmpdir.path().to_str().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap();
let _ = SqliteStore::open("@example:localhost", "DEVICEID", tmpdir_path) let _ = SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), "DEVICEID", tmpdir_path)
.await .await
.expect("Can't create store"); .expect("Can't create store");
} }
@ -626,7 +630,7 @@ mod test {
let session = InboundGroupSession::new( let session = InboundGroupSession::new(
identity_keys.curve25519(), identity_keys.curve25519(),
identity_keys.ed25519(), identity_keys.ed25519(),
"!test:localhost", &RoomId::try_from("!test:localhost").unwrap(),
&outbound_session.session_key(), &outbound_session.session_key(),
) )
.expect("Can't create session"); .expect("Can't create session");
@ -647,7 +651,7 @@ mod test {
let session = InboundGroupSession::new( let session = InboundGroupSession::new(
identity_keys.curve25519(), identity_keys.curve25519(),
identity_keys.ed25519(), identity_keys.ed25519(),
"!test:localhost", &RoomId::try_from("!test:localhost").unwrap(),
&outbound_session.session_key(), &outbound_session.session_key(),
) )
.expect("Can't create session"); .expect("Can't create session");

View File

@ -65,7 +65,7 @@ use tokio::sync::Mutex;
/// } = event.lock().await.deref() /// } = event.lock().await.deref()
/// { /// {
/// let rooms = room.lock().await; /// let rooms = room.lock().await;
/// let member = rooms.members.get(&sender.to_string()).unwrap(); /// let member = rooms.members.get(&sender).unwrap();
/// println!( /// println!(
/// "{}: {}", /// "{}: {}",
/// member /// member