crypto compiles, tests pass
parent
db62c1fa25
commit
eff322c0c5
|
@ -623,7 +623,7 @@ impl AsyncClient {
|
||||||
/// * `data` - The content of the message.
|
/// * `data` - The content of the message.
|
||||||
pub async fn room_send(
|
pub async fn room_send(
|
||||||
&mut self,
|
&mut self,
|
||||||
room_id: &str,
|
room_id: &RoomId,
|
||||||
data: MessageEventContent,
|
data: MessageEventContent,
|
||||||
) -> Result<create_message_event::Response> {
|
) -> Result<create_message_event::Response> {
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
|
@ -658,7 +658,7 @@ impl AsyncClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
let request = create_message_event::Request {
|
let request = create_message_event::Request {
|
||||||
room_id: RoomId::try_from(room_id).unwrap(),
|
room_id: room_id.clone(),
|
||||||
event_type: EventType::RoomMessage,
|
event_type: EventType::RoomMessage,
|
||||||
txn_id: self.transaction_id().to_string(),
|
txn_id: self.transaction_id().to_string(),
|
||||||
data,
|
data,
|
||||||
|
@ -771,7 +771,7 @@ impl AsyncClient {
|
||||||
let mut device_keys: HashMap<UserId, Vec<DeviceId>> = HashMap::new();
|
let mut device_keys: HashMap<UserId, Vec<DeviceId>> = HashMap::new();
|
||||||
|
|
||||||
for user in users_for_query.drain() {
|
for user in users_for_query.drain() {
|
||||||
device_keys.insert(UserId::try_from(user.as_ref()).unwrap(), Vec::new());
|
device_keys.insert(user, Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
let request = get_keys::Request {
|
let request = get_keys::Request {
|
||||||
|
|
|
@ -400,8 +400,8 @@ impl Client {
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
|
||||||
pub async fn get_missing_sessions(
|
pub async fn get_missing_sessions(
|
||||||
&self,
|
&self,
|
||||||
users: impl Iterator<Item = &String>,
|
users: impl Iterator<Item = &UserId>,
|
||||||
) -> HashMap<RumaUserId, HashMap<DeviceId, KeyAlgorithm>> {
|
) -> HashMap<UserId, HashMap<DeviceId, KeyAlgorithm>> {
|
||||||
let mut olm = self.olm.lock().await;
|
let mut olm = self.olm.lock().await;
|
||||||
|
|
||||||
match &mut *olm {
|
match &mut *olm {
|
||||||
|
@ -431,7 +431,7 @@ impl Client {
|
||||||
/// Returns an empty error if no keys need to be queried.
|
/// Returns an empty error if no keys need to be queried.
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
|
||||||
pub async fn users_for_key_query(&self) -> StdResult<HashSet<String>, ()> {
|
pub async fn users_for_key_query(&self) -> StdResult<HashSet<UserId>, ()> {
|
||||||
let olm = self.olm.lock().await;
|
let olm = self.olm.lock().await;
|
||||||
|
|
||||||
match &*olm {
|
match &*olm {
|
||||||
|
|
|
@ -71,7 +71,7 @@ pub struct OlmMachine {
|
||||||
store: Box<dyn CryptoStore>,
|
store: Box<dyn CryptoStore>,
|
||||||
/// Set of users that we need to query keys for. This is a subset of
|
/// Set of users that we need to query keys for. This is a subset of
|
||||||
/// the tracked users in the CryptoStore.
|
/// the tracked users in the CryptoStore.
|
||||||
users_for_key_query: HashSet<String>,
|
users_for_key_query: HashSet<UserId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OlmMachine {
|
impl OlmMachine {
|
||||||
|
@ -101,8 +101,7 @@ impl OlmMachine {
|
||||||
passphrase: String,
|
passphrase: String,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let mut store =
|
let mut store =
|
||||||
SqliteStore::open_with_passphrase(&user_id.to_string(), device_id, path, passphrase)
|
SqliteStore::open_with_passphrase(&user_id, device_id, path, passphrase).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
let account = match store.load_account().await? {
|
let account = match store.load_account().await? {
|
||||||
Some(a) => {
|
Some(a) => {
|
||||||
|
@ -183,12 +182,12 @@ impl OlmMachine {
|
||||||
|
|
||||||
pub async fn get_missing_sessions(
|
pub async fn get_missing_sessions(
|
||||||
&mut self,
|
&mut self,
|
||||||
users: impl Iterator<Item = &String>,
|
users: impl Iterator<Item = &UserId>,
|
||||||
) -> HashMap<UserId, HashMap<DeviceId, KeyAlgorithm>> {
|
) -> HashMap<UserId, HashMap<DeviceId, KeyAlgorithm>> {
|
||||||
let mut missing = HashMap::new();
|
let mut missing = HashMap::new();
|
||||||
|
|
||||||
for user_id in users {
|
for user_id in users {
|
||||||
let user_devices = self.store.get_user_devices(&user_id).await.unwrap();
|
let user_devices = self.store.get_user_devices(user_id).await.unwrap();
|
||||||
|
|
||||||
for device in user_devices.devices() {
|
for device in user_devices.devices() {
|
||||||
let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) {
|
let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) {
|
||||||
|
@ -206,12 +205,11 @@ impl OlmMachine {
|
||||||
};
|
};
|
||||||
|
|
||||||
if is_missing {
|
if is_missing {
|
||||||
let user_id = UserId::try_from(user_id.as_ref()).unwrap();
|
if !missing.contains_key(user_id) {
|
||||||
if !missing.contains_key(&user_id) {
|
missing.insert(user_id.clone(), HashMap::new());
|
||||||
missing.insert(user_id.to_owned(), HashMap::new());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let user_map = missing.get_mut(&user_id).unwrap();
|
let user_map = missing.get_mut(user_id).unwrap();
|
||||||
user_map.insert(
|
user_map.insert(
|
||||||
device.device_id().to_owned(),
|
device.device_id().to_owned(),
|
||||||
KeyAlgorithm::SignedCurve25519,
|
KeyAlgorithm::SignedCurve25519,
|
||||||
|
@ -233,7 +231,7 @@ impl OlmMachine {
|
||||||
for (device_id, key_map) in user_devices {
|
for (device_id, key_map) in user_devices {
|
||||||
let device = if let Some(d) = self
|
let device = if let Some(d) = self
|
||||||
.store
|
.store
|
||||||
.get_device(&user_id.to_string(), device_id)
|
.get_device(&user_id, device_id)
|
||||||
.await
|
.await
|
||||||
.expect("Can't get devices")
|
.expect("Can't get devices")
|
||||||
{
|
{
|
||||||
|
@ -346,8 +344,7 @@ impl OlmMachine {
|
||||||
let mut changed_devices = Vec::new();
|
let mut changed_devices = Vec::new();
|
||||||
|
|
||||||
for (user_id, device_map) in &response.device_keys {
|
for (user_id, device_map) in &response.device_keys {
|
||||||
let user_id_string = user_id.to_string();
|
self.users_for_key_query.remove(&user_id);
|
||||||
self.users_for_key_query.remove(&user_id_string);
|
|
||||||
|
|
||||||
for (device_id, device_keys) in device_map.iter() {
|
for (device_id, device_keys) in device_map.iter() {
|
||||||
// We don't need our own device in the device store.
|
// We don't need our own device in the device store.
|
||||||
|
@ -393,7 +390,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
let device = self
|
let device = self
|
||||||
.store
|
.store
|
||||||
.get_device(&user_id_string, device_id)
|
.get_device(&user_id, device_id)
|
||||||
.await
|
.await
|
||||||
.expect("Can't load device");
|
.expect("Can't load device");
|
||||||
|
|
||||||
|
@ -407,7 +404,7 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
|
|
||||||
let current_devices: HashSet<&String> = device_map.keys().collect();
|
let current_devices: HashSet<&String> = device_map.keys().collect();
|
||||||
let stored_devices = self.store.get_user_devices(&user_id_string).await.unwrap();
|
let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
|
||||||
let stored_devices_set: HashSet<&String> = stored_devices.keys().collect();
|
let stored_devices_set: HashSet<&String> = stored_devices.keys().collect();
|
||||||
|
|
||||||
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
||||||
|
@ -767,7 +764,7 @@ impl OlmMachine {
|
||||||
let session = InboundGroupSession::new(
|
let session = InboundGroupSession::new(
|
||||||
sender_key,
|
sender_key,
|
||||||
signing_key,
|
signing_key,
|
||||||
&event.content.room_id.to_string(),
|
&event.content.room_id,
|
||||||
&event.content.session_key,
|
&event.content.session_key,
|
||||||
)?;
|
)?;
|
||||||
self.store.save_inbound_group_session(session).await?;
|
self.store.save_inbound_group_session(session).await?;
|
||||||
|
@ -893,11 +890,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
let session = self
|
let session = self
|
||||||
.store
|
.store
|
||||||
.get_inbound_group_session(
|
.get_inbound_group_session(&room_id, &content.sender_key, &content.session_id)
|
||||||
&room_id.to_string(),
|
|
||||||
&content.sender_key,
|
|
||||||
&content.session_id,
|
|
||||||
)
|
|
||||||
.await?;
|
.await?;
|
||||||
// TODO check if the olm session is wedged and re-request the key.
|
// TODO check if the olm session is wedged and re-request the key.
|
||||||
let session = session.ok_or(OlmError::MissingSession)?;
|
let session = session.ok_or(OlmError::MissingSession)?;
|
||||||
|
@ -936,7 +929,7 @@ impl OlmMachine {
|
||||||
/// Use the `mark_user_as_changed()` if the user really needs a key query.
|
/// Use the `mark_user_as_changed()` if the user really needs a key query.
|
||||||
pub async fn update_tracked_users<'a, I>(&mut self, users: I)
|
pub async fn update_tracked_users<'a, I>(&mut self, users: I)
|
||||||
where
|
where
|
||||||
I: IntoIterator<Item = &'a String>,
|
I: IntoIterator<Item = &'a UserId>,
|
||||||
{
|
{
|
||||||
for user in users {
|
for user in users {
|
||||||
let ret = self.store.add_user_for_tracking(user).await;
|
let ret = self.store.add_user_for_tracking(user).await;
|
||||||
|
@ -944,12 +937,12 @@ impl OlmMachine {
|
||||||
match ret {
|
match ret {
|
||||||
Ok(newly_added) => {
|
Ok(newly_added) => {
|
||||||
if newly_added {
|
if newly_added {
|
||||||
self.users_for_key_query.insert(user.to_string());
|
self.users_for_key_query.insert(user.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Error storing users for tracking {}", e);
|
warn!("Error storing users for tracking {}", e);
|
||||||
self.users_for_key_query.insert(user.to_string());
|
self.users_for_key_query.insert(user.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -961,7 +954,7 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the set of users that we need to query keys for.
|
/// Get the set of users that we need to query keys for.
|
||||||
pub fn users_for_key_query(&self) -> HashSet<String> {
|
pub fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||||
self.users_for_key_query.clone()
|
self.users_for_key_query.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::convert::TryFrom;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use dashmap::{DashMap, ReadOnlyView};
|
use dashmap::{DashMap, ReadOnlyView};
|
||||||
|
@ -20,6 +21,7 @@ use tokio::sync::Mutex;
|
||||||
|
|
||||||
use super::device::Device;
|
use super::device::Device;
|
||||||
use super::olm::{InboundGroupSession, Session};
|
use super::olm::{InboundGroupSession, Session};
|
||||||
|
use crate::identifiers::{RoomId, UserId};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct SessionStore {
|
pub struct SessionStore {
|
||||||
|
@ -59,7 +61,7 @@ impl SessionStore {
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GroupSessionStore {
|
pub struct GroupSessionStore {
|
||||||
entries: HashMap<String, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>,
|
entries: HashMap<RoomId, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GroupSessionStore {
|
impl GroupSessionStore {
|
||||||
|
@ -89,7 +91,7 @@ impl GroupSessionStore {
|
||||||
|
|
||||||
pub fn get(
|
pub fn get(
|
||||||
&self,
|
&self,
|
||||||
room_id: &str,
|
room_id: &RoomId,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Option<Arc<Mutex<InboundGroupSession>>> {
|
) -> Option<Arc<Mutex<InboundGroupSession>>> {
|
||||||
|
@ -101,7 +103,7 @@ impl GroupSessionStore {
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct DeviceStore {
|
pub struct DeviceStore {
|
||||||
entries: Arc<DashMap<String, DashMap<String, Device>>>,
|
entries: Arc<DashMap<UserId, DashMap<String, Device>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct UserDevices {
|
pub struct UserDevices {
|
||||||
|
@ -130,26 +132,27 @@ impl DeviceStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add(&self, device: Device) -> bool {
|
pub fn add(&self, device: Device) -> bool {
|
||||||
if !self.entries.contains_key(device.user_id()) {
|
let user_id = UserId::try_from(device.user_id()).unwrap();
|
||||||
self.entries
|
|
||||||
.insert(device.user_id().to_owned(), DashMap::new());
|
if !self.entries.contains_key(&user_id) {
|
||||||
|
self.entries.insert(user_id.clone(), DashMap::new());
|
||||||
}
|
}
|
||||||
let device_map = self.entries.get_mut(device.user_id()).unwrap();
|
let device_map = self.entries.get_mut(&user_id).unwrap();
|
||||||
|
|
||||||
device_map
|
device_map
|
||||||
.insert(device.device_id().to_owned(), device)
|
.insert(device.device_id().to_owned(), device)
|
||||||
.is_some()
|
.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, user_id: &str, device_id: &str) -> Option<Device> {
|
pub fn get(&self, user_id: &UserId, device_id: &str) -> Option<Device> {
|
||||||
self.entries
|
self.entries
|
||||||
.get(user_id)
|
.get(user_id)
|
||||||
.and_then(|m| m.get(device_id).map(|d| d.value().clone()))
|
.and_then(|m| m.get(device_id).map(|d| d.value().clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn user_devices(&self, user_id: &str) -> UserDevices {
|
pub fn user_devices(&self, user_id: &UserId) -> UserDevices {
|
||||||
if !self.entries.contains_key(user_id) {
|
if !self.entries.contains_key(user_id) {
|
||||||
self.entries.insert(user_id.to_owned(), DashMap::new());
|
self.entries.insert(user_id.clone(), DashMap::new());
|
||||||
}
|
}
|
||||||
UserDevices {
|
UserDevices {
|
||||||
entries: self.entries.get(user_id).unwrap().clone().into_read_only(),
|
entries: self.entries.get(user_id).unwrap().clone().into_read_only(),
|
||||||
|
|
|
@ -23,6 +23,8 @@ use olm_rs::PicklingMode;
|
||||||
|
|
||||||
use ruma_client_api::r0::keys::SignedKey;
|
use ruma_client_api::r0::keys::SignedKey;
|
||||||
|
|
||||||
|
use crate::identifiers::{RoomId, UserId};
|
||||||
|
|
||||||
pub struct Account {
|
pub struct Account {
|
||||||
inner: OlmAccount,
|
inner: OlmAccount,
|
||||||
pub(crate) shared: bool,
|
pub(crate) shared: bool,
|
||||||
|
@ -210,7 +212,7 @@ pub struct InboundGroupSession {
|
||||||
inner: OlmInboundGroupSession,
|
inner: OlmInboundGroupSession,
|
||||||
pub(crate) sender_key: String,
|
pub(crate) sender_key: String,
|
||||||
pub(crate) signing_key: String,
|
pub(crate) signing_key: String,
|
||||||
pub(crate) room_id: String,
|
pub(crate) room_id: RoomId,
|
||||||
forwarding_chains: Option<Vec<String>>,
|
forwarding_chains: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,14 +220,14 @@ impl InboundGroupSession {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
signing_key: &str,
|
signing_key: &str,
|
||||||
room_id: &str,
|
room_id: &RoomId,
|
||||||
session_key: &str,
|
session_key: &str,
|
||||||
) -> Result<Self, OlmGroupSessionError> {
|
) -> Result<Self, OlmGroupSessionError> {
|
||||||
Ok(InboundGroupSession {
|
Ok(InboundGroupSession {
|
||||||
inner: OlmInboundGroupSession::new(session_key)?,
|
inner: OlmInboundGroupSession::new(session_key)?,
|
||||||
sender_key: sender_key.to_owned(),
|
sender_key: sender_key.to_owned(),
|
||||||
signing_key: signing_key.to_owned(),
|
signing_key: signing_key.to_owned(),
|
||||||
room_id: room_id.to_owned(),
|
room_id: room_id.clone(),
|
||||||
forwarding_chains: None,
|
forwarding_chains: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -235,7 +237,7 @@ impl InboundGroupSession {
|
||||||
pickle_mode: PicklingMode,
|
pickle_mode: PicklingMode,
|
||||||
sender_key: String,
|
sender_key: String,
|
||||||
signing_key: String,
|
signing_key: String,
|
||||||
room_id: String,
|
room_id: RoomId,
|
||||||
) -> Result<Self, OlmGroupSessionError> {
|
) -> Result<Self, OlmGroupSessionError> {
|
||||||
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
|
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
|
||||||
Ok(InboundGroupSession {
|
Ok(InboundGroupSession {
|
||||||
|
|
|
@ -21,12 +21,13 @@ use tokio::sync::Mutex;
|
||||||
use super::{Account, CryptoStore, InboundGroupSession, Result, Session};
|
use super::{Account, CryptoStore, InboundGroupSession, Result, Session};
|
||||||
use crate::crypto::device::Device;
|
use crate::crypto::device::Device;
|
||||||
use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
|
use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
|
||||||
|
use crate::identifiers::{RoomId, UserId};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MemoryStore {
|
pub struct MemoryStore {
|
||||||
sessions: SessionStore,
|
sessions: SessionStore,
|
||||||
inbound_group_sessions: GroupSessionStore,
|
inbound_group_sessions: GroupSessionStore,
|
||||||
tracked_users: HashSet<String>,
|
tracked_users: HashSet<UserId>,
|
||||||
devices: DeviceStore,
|
devices: DeviceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +74,7 @@ impl CryptoStore for MemoryStore {
|
||||||
|
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
&mut self,
|
&mut self,
|
||||||
room_id: &str,
|
room_id: &RoomId,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
||||||
|
@ -82,19 +83,19 @@ impl CryptoStore for MemoryStore {
|
||||||
.get(room_id, sender_key, session_id))
|
.get(room_id, sender_key, session_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tracked_users(&self) -> &HashSet<String> {
|
fn tracked_users(&self) -> &HashSet<UserId> {
|
||||||
&self.tracked_users
|
&self.tracked_users
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> {
|
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
|
||||||
Ok(self.tracked_users.insert(user.to_string()))
|
Ok(self.tracked_users.insert(user.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_device(&self, user_id: &str, device_id: &str) -> Result<Option<Device>> {
|
async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result<Option<Device>> {
|
||||||
Ok(self.devices.get(user_id, device_id))
|
Ok(self.devices.get(user_id, device_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_user_devices(&self, user_id: &str) -> Result<UserDevices> {
|
async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
|
||||||
Ok(self.devices.user_devices(user_id))
|
Ok(self.devices.user_devices(user_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ use tokio::sync::Mutex;
|
||||||
use super::device::Device;
|
use super::device::Device;
|
||||||
use super::memory_stores::UserDevices;
|
use super::memory_stores::UserDevices;
|
||||||
use super::olm::{Account, InboundGroupSession, Session};
|
use super::olm::{Account, InboundGroupSession, Session};
|
||||||
|
use crate::identifiers::{RoomId, UserId};
|
||||||
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
||||||
|
|
||||||
pub mod memorystore;
|
pub mod memorystore;
|
||||||
|
@ -75,13 +76,13 @@ pub trait CryptoStore: Debug + Send + Sync {
|
||||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
|
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
&mut self,
|
&mut self,
|
||||||
room_id: &str,
|
room_id: &RoomId,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
|
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>;
|
||||||
fn tracked_users(&self) -> &HashSet<String>;
|
fn tracked_users(&self) -> &HashSet<UserId>;
|
||||||
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool>;
|
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
|
||||||
async fn save_device(&self, device: Device) -> Result<()>;
|
async fn save_device(&self, device: Device) -> Result<()>;
|
||||||
async fn get_device(&self, user_id: &str, device_id: &str) -> Result<Option<Device>>;
|
async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result<Option<Device>>;
|
||||||
async fn get_user_devices(&self, user_id: &str) -> Result<UserDevices>;
|
async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
use std::convert::TryFrom;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::result::Result as StdResult;
|
use std::result::Result as StdResult;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -29,6 +30,7 @@ use zeroize::Zeroizing;
|
||||||
use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session};
|
use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session};
|
||||||
use crate::crypto::device::Device;
|
use crate::crypto::device::Device;
|
||||||
use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices};
|
use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices};
|
||||||
|
use crate::identifiers::{RoomId, UserId};
|
||||||
|
|
||||||
pub struct SqliteStore {
|
pub struct SqliteStore {
|
||||||
user_id: Arc<String>,
|
user_id: Arc<String>,
|
||||||
|
@ -39,14 +41,14 @@ pub struct SqliteStore {
|
||||||
inbound_group_sessions: GroupSessionStore,
|
inbound_group_sessions: GroupSessionStore,
|
||||||
connection: Arc<Mutex<SqliteConnection>>,
|
connection: Arc<Mutex<SqliteConnection>>,
|
||||||
pickle_passphrase: Option<Zeroizing<String>>,
|
pickle_passphrase: Option<Zeroizing<String>>,
|
||||||
tracked_users: HashSet<String>,
|
tracked_users: HashSet<UserId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
|
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
|
||||||
|
|
||||||
impl SqliteStore {
|
impl SqliteStore {
|
||||||
pub async fn open<P: AsRef<Path>>(
|
pub async fn open<P: AsRef<Path>>(
|
||||||
user_id: &str,
|
user_id: &UserId,
|
||||||
device_id: &str,
|
device_id: &str,
|
||||||
path: P,
|
path: P,
|
||||||
) -> Result<SqliteStore> {
|
) -> Result<SqliteStore> {
|
||||||
|
@ -54,7 +56,7 @@ impl SqliteStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn open_with_passphrase<P: AsRef<Path>>(
|
pub async fn open_with_passphrase<P: AsRef<Path>>(
|
||||||
user_id: &str,
|
user_id: &UserId,
|
||||||
device_id: &str,
|
device_id: &str,
|
||||||
path: P,
|
path: P,
|
||||||
passphrase: String,
|
passphrase: String,
|
||||||
|
@ -69,7 +71,7 @@ impl SqliteStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn open_helper<P: AsRef<Path>>(
|
async fn open_helper<P: AsRef<Path>>(
|
||||||
user_id: &str,
|
user_id: &UserId,
|
||||||
device_id: &str,
|
device_id: &str,
|
||||||
path: P,
|
path: P,
|
||||||
passphrase: Option<Zeroizing<String>>,
|
passphrase: Option<Zeroizing<String>>,
|
||||||
|
@ -78,7 +80,7 @@ impl SqliteStore {
|
||||||
|
|
||||||
let connection = SqliteConnection::connect(url.as_ref()).await?;
|
let connection = SqliteConnection::connect(url.as_ref()).await?;
|
||||||
let store = SqliteStore {
|
let store = SqliteStore {
|
||||||
user_id: Arc::new(user_id.to_owned()),
|
user_id: Arc::new(user_id.to_string()),
|
||||||
device_id: Arc::new(device_id.to_owned()),
|
device_id: Arc::new(device_id.to_owned()),
|
||||||
account_id: None,
|
account_id: None,
|
||||||
sessions: SessionStore::new(),
|
sessions: SessionStore::new(),
|
||||||
|
@ -230,7 +232,7 @@ impl SqliteStore {
|
||||||
self.get_pickle_mode(),
|
self.get_pickle_mode(),
|
||||||
sender_key.to_string(),
|
sender_key.to_string(),
|
||||||
signing_key.to_owned(),
|
signing_key.to_owned(),
|
||||||
room_id.to_owned(),
|
RoomId::try_from(room_id.as_str()).unwrap(),
|
||||||
)?)
|
)?)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
||||||
|
@ -302,8 +304,8 @@ impl CryptoStore for SqliteStore {
|
||||||
device_id = ?2
|
device_id = ?2
|
||||||
",
|
",
|
||||||
)
|
)
|
||||||
.bind(&*self.user_id)
|
.bind(&*self.user_id.to_string())
|
||||||
.bind(&*self.device_id)
|
.bind(&*self.device_id.to_string())
|
||||||
.bind(&pickle)
|
.bind(&pickle)
|
||||||
.bind(acc.shared)
|
.bind(acc.shared)
|
||||||
.execute(&mut *connection)
|
.execute(&mut *connection)
|
||||||
|
@ -311,8 +313,8 @@ impl CryptoStore for SqliteStore {
|
||||||
|
|
||||||
let account_id: (i64,) =
|
let account_id: (i64,) =
|
||||||
query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?")
|
query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?")
|
||||||
.bind(&*self.user_id)
|
.bind(&*self.user_id.to_string())
|
||||||
.bind(&*self.device_id)
|
.bind(&*self.device_id.to_string())
|
||||||
.fetch_one(&mut *connection)
|
.fetch_one(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -383,7 +385,7 @@ impl CryptoStore for SqliteStore {
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
.bind(&session.sender_key)
|
.bind(&session.sender_key)
|
||||||
.bind(&session.signing_key)
|
.bind(&session.signing_key)
|
||||||
.bind(&session.room_id)
|
.bind(&session.room_id.to_string())
|
||||||
.bind(&pickle)
|
.bind(&pickle)
|
||||||
.execute(&mut *connection)
|
.execute(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -393,7 +395,7 @@ impl CryptoStore for SqliteStore {
|
||||||
|
|
||||||
async fn get_inbound_group_session(
|
async fn get_inbound_group_session(
|
||||||
&mut self,
|
&mut self,
|
||||||
room_id: &str,
|
room_id: &RoomId,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> {
|
||||||
|
@ -402,19 +404,19 @@ impl CryptoStore for SqliteStore {
|
||||||
.get(room_id, sender_key, session_id))
|
.get(room_id, sender_key, session_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tracked_users(&self) -> &HashSet<String> {
|
fn tracked_users(&self) -> &HashSet<UserId> {
|
||||||
&self.tracked_users
|
&self.tracked_users
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_user_for_tracking(&mut self, user: &str) -> Result<bool> {
|
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
|
||||||
Ok(self.tracked_users.insert(user.to_string()))
|
Ok(self.tracked_users.insert(user.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_device(&self, _user_id: &str, _device_id: &str) -> Result<Option<Device>> {
|
async fn get_device(&self, _user_id: &UserId, _device_id: &str) -> Result<Option<Device>> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_user_devices(&self, _user_id: &str) -> Result<UserDevices> {
|
async fn get_user_devices(&self, _user_id: &UserId) -> Result<UserDevices> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -442,7 +444,9 @@ mod test {
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
use super::{Account, CryptoStore, InboundGroupSession, Session, SqliteStore};
|
use super::{
|
||||||
|
Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom, UserId,
|
||||||
|
};
|
||||||
|
|
||||||
static USER_ID: &str = "@example:localhost";
|
static USER_ID: &str = "@example:localhost";
|
||||||
static DEVICE_ID: &str = "DEVICEID";
|
static DEVICE_ID: &str = "DEVICEID";
|
||||||
|
@ -450,7 +454,7 @@ mod test {
|
||||||
async fn get_store() -> SqliteStore {
|
async fn get_store() -> SqliteStore {
|
||||||
let tmpdir = tempdir().unwrap();
|
let tmpdir = tempdir().unwrap();
|
||||||
let tmpdir_path = tmpdir.path().to_str().unwrap();
|
let tmpdir_path = tmpdir.path().to_str().unwrap();
|
||||||
SqliteStore::open(USER_ID, DEVICE_ID, tmpdir_path)
|
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, tmpdir_path)
|
||||||
.await
|
.await
|
||||||
.expect("Can't create store")
|
.expect("Can't create store")
|
||||||
}
|
}
|
||||||
|
@ -501,7 +505,7 @@ mod test {
|
||||||
async fn create_store() {
|
async fn create_store() {
|
||||||
let tmpdir = tempdir().unwrap();
|
let tmpdir = tempdir().unwrap();
|
||||||
let tmpdir_path = tmpdir.path().to_str().unwrap();
|
let tmpdir_path = tmpdir.path().to_str().unwrap();
|
||||||
let _ = SqliteStore::open("@example:localhost", "DEVICEID", tmpdir_path)
|
let _ = SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), "DEVICEID", tmpdir_path)
|
||||||
.await
|
.await
|
||||||
.expect("Can't create store");
|
.expect("Can't create store");
|
||||||
}
|
}
|
||||||
|
@ -626,7 +630,7 @@ mod test {
|
||||||
let session = InboundGroupSession::new(
|
let session = InboundGroupSession::new(
|
||||||
identity_keys.curve25519(),
|
identity_keys.curve25519(),
|
||||||
identity_keys.ed25519(),
|
identity_keys.ed25519(),
|
||||||
"!test:localhost",
|
&RoomId::try_from("!test:localhost").unwrap(),
|
||||||
&outbound_session.session_key(),
|
&outbound_session.session_key(),
|
||||||
)
|
)
|
||||||
.expect("Can't create session");
|
.expect("Can't create session");
|
||||||
|
@ -647,7 +651,7 @@ mod test {
|
||||||
let session = InboundGroupSession::new(
|
let session = InboundGroupSession::new(
|
||||||
identity_keys.curve25519(),
|
identity_keys.curve25519(),
|
||||||
identity_keys.ed25519(),
|
identity_keys.ed25519(),
|
||||||
"!test:localhost",
|
&RoomId::try_from("!test:localhost").unwrap(),
|
||||||
&outbound_session.session_key(),
|
&outbound_session.session_key(),
|
||||||
)
|
)
|
||||||
.expect("Can't create session");
|
.expect("Can't create session");
|
||||||
|
|
|
@ -65,7 +65,7 @@ use tokio::sync::Mutex;
|
||||||
/// } = event.lock().await.deref()
|
/// } = event.lock().await.deref()
|
||||||
/// {
|
/// {
|
||||||
/// let rooms = room.lock().await;
|
/// let rooms = room.lock().await;
|
||||||
/// let member = rooms.members.get(&sender.to_string()).unwrap();
|
/// let member = rooms.members.get(&sender).unwrap();
|
||||||
/// println!(
|
/// println!(
|
||||||
/// "{}: {}",
|
/// "{}: {}",
|
||||||
/// member
|
/// member
|
||||||
|
|
Loading…
Reference in New Issue