diff --git a/Cargo.toml b/Cargo.toml index c160a20d..5bba864e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ serde_json = "1.0.51" # Ruma dependencies js_int = "0.1.4" -ruma-api = "0.15.0" +ruma-api = "0.15.1" ruma-client-api = { git = "https://github.com/matrix-org/ruma-client-api/", version = "0.7.0" } ruma-events = { git = "https://github.com/matrix-org/ruma-events", version = "0.18.0" } ruma-identifiers = "0.14.1" diff --git a/src/async_client.rs b/src/async_client.rs index 5e298a8b..f206b9e9 100644 --- a/src/async_client.rs +++ b/src/async_client.rs @@ -561,7 +561,7 @@ impl AsyncClient { let mut response = self.send(request).await?; for (room_id, room) in &mut response.rooms.join { - let _matrix_room = { + let matrix_room = { let mut client = self.base_client.write().await; for event in &room.state.events { if let EventResult::Ok(e) = event { @@ -572,6 +572,9 @@ impl AsyncClient { client.get_or_create_room(&room_id).clone() }; + // RoomSummary contains information for calculating room name + matrix_room.write().await.set_room_summary(&room.summary); + // re looping is not ideal here for event in &mut room.state.events { if let EventResult::Ok(e) = event { @@ -754,7 +757,7 @@ impl AsyncClient { } } - async fn send( + async fn send + std::fmt::Debug>( &self, request: Request, ) -> Result<::Incoming> @@ -815,21 +818,20 @@ impl AsyncClient { trace!("Got response: {:?}", response); let status = response.status(); - let mut http_response = HttpResponse::builder().status(status); - let headers = http_response.headers_mut().unwrap(); + let mut http_builder = HttpResponse::builder().status(status); + let headers = http_builder.headers_mut().unwrap(); for (k, v) in response.headers_mut().drain() { if let Some(key) = k { headers.insert(key, v); } } - let body = response.bytes().await?.as_ref().to_owned(); - let http_response = http_response.body(body).unwrap(); - let response = ::Incoming::try_from(http_response) - .expect("Can't convert http response into ruma response"); + let http_response = http_builder.body(body).unwrap(); - Ok(response) + Ok(::Incoming::try_from( + http_response, + )?) } /// Send a room message to the homeserver. @@ -1106,7 +1108,9 @@ mod test { use crate::test_builder::EventBuilder; + use mockito::mock; use std::convert::TryFrom; + use std::str::FromStr; #[tokio::test] async fn client_runner() { @@ -1168,4 +1172,44 @@ mod test { &Url::parse(&mockito::server_url()).unwrap() ); } + + #[tokio::test] + async fn login_error() { + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + + let _m = mock("POST", "/_matrix/client/r0/login") + .with_status(403) + .with_body_from_file("tests/data/login_response_error.json") + .create(); + + let client = AsyncClient::new(homeserver, None).unwrap(); + + if let Err(err) = client.login("example", "wordpass", None, None).await { + if let crate::Error::RumaResponse(ruma_api::error::FromHttpResponseError::Http( + ruma_api::error::ServerError::Known(ruma_client_api::error::Error { + kind, + message, + status_code, + }), + )) = err + { + if let ruma_client_api::error::ErrorKind::Forbidden = kind { + } else { + panic!( + "found the wrong `ErrorKind` {:?}, expected `Forbidden", + kind + ); + } + assert_eq!(message, "Invalid password".to_string()); + assert_eq!(status_code, http::StatusCode::from_u16(403).unwrap()); + } else { + panic!( + "found the wrong `Error` type {:?}, expected `Error::RumaResponse", + err + ); + } + } else { + panic!("this request should return an `Err` variant") + } + } } diff --git a/src/base_client.rs b/src/base_client.rs index 33a0314f..30cefc79 100644 --- a/src/base_client.rs +++ b/src/base_client.rs @@ -160,7 +160,7 @@ impl Client { pub(crate) async fn calculate_room_name(&self, room_id: &RoomId) -> Option { if let Some(room) = self.joined_rooms.get(room_id) { let room = room.read().await; - Some(room.room_name.calculate_name(room_id, &room.members)) + Some(room.room_name.calculate_name(&room.members)) } else { None } @@ -168,9 +168,9 @@ impl Client { pub(crate) async fn calculate_room_names(&self) -> Vec { let mut res = Vec::new(); - for (id, room) in &self.joined_rooms { + for (_id, room) in &self.joined_rooms { let room = room.read().await; - res.push(room.room_name.calculate_name(id, &room.members)) + res.push(room.room_name.calculate_name(&room.members)) } res } diff --git a/src/crypto/device.rs b/src/crypto/device.rs index 44bb8e3b..f6c34a26 100644 --- a/src/crypto/device.rs +++ b/src/crypto/device.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::collections::HashMap; -use std::sync::atomic::AtomicBool; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use atomic::Atomic; @@ -26,6 +26,8 @@ use crate::identifiers::{DeviceId, UserId}; pub struct Device { user_id: Arc, device_id: Arc, + // TODO the algorithm and the keys might change, so we can't make them read + // only here. Perhaps dashmap and a rwlock on the algorithms. algorithms: Arc>, keys: Arc>, display_name: Arc>, @@ -33,26 +35,86 @@ pub struct Device { trust_state: Arc>, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] +/// The trust state of a device. pub enum TrustState { - Verified, - BlackListed, - Ignored, - Unset, + /// The device has been verified and is trusted. + Verified = 0, + /// The device been blacklisted from communicating. + BlackListed = 1, + /// The trust state of the device is being ignored. + Ignored = 2, + /// The trust state is unset. + Unset = 3, +} + +impl From for TrustState { + fn from(state: i64) -> Self { + match state { + 0 => TrustState::Verified, + 1 => TrustState::BlackListed, + 2 => TrustState::Ignored, + 3 => TrustState::Unset, + _ => TrustState::Unset, + } + } } impl Device { - pub fn device_id(&self) -> &DeviceId { - &self.device_id + /// Create a new Device. + pub fn new( + user_id: UserId, + device_id: DeviceId, + display_name: Option, + trust_state: TrustState, + algorithms: Vec, + keys: HashMap, + ) -> Self { + Device { + user_id: Arc::new(user_id), + device_id: Arc::new(device_id), + display_name: Arc::new(display_name), + trust_state: Arc::new(Atomic::new(trust_state)), + algorithms: Arc::new(algorithms), + keys: Arc::new(keys), + deleted: Arc::new(AtomicBool::new(false)), + } } + /// The user id of the device owner. pub fn user_id(&self) -> &UserId { &self.user_id } - pub fn keys(&self, algorithm: &KeyAlgorithm) -> Option<&String> { + /// The unique ID of the device. + pub fn device_id(&self) -> &DeviceId { + &self.device_id + } + + /// Get the human readable name of the device. + pub fn display_name(&self) -> &Option { + &self.display_name + } + + /// Get the key of the given key algorithm belonging to this device. + pub fn get_key(&self, algorithm: &KeyAlgorithm) -> Option<&String> { self.keys.get(algorithm) } + + /// Get a map containing all the device keys. + pub fn keys(&self) -> &HashMap { + &self.keys + } + + /// Get the trust state of the device. + pub fn trust_state(&self) -> TrustState { + self.trust_state.load(Ordering::Relaxed) + } + + /// Get the list of algorithms this device supports. + pub fn algorithms(&self) -> &[Algorithm] { + &self.algorithms + } } impl From<&DeviceKeys> for Device { @@ -93,7 +155,7 @@ pub(crate) mod test { use std::convert::{From, TryFrom}; use crate::api::r0::keys::{DeviceKeys, KeyAlgorithm}; - use crate::crypto::device::Device; + use crate::crypto::device::{Device, TrustState}; use crate::identifiers::UserId; pub(crate) fn get_device() -> Device { @@ -136,12 +198,17 @@ pub(crate) mod test { assert_eq!(&user_id, device.user_id()); assert_eq!(device_id, device.device_id()); assert_eq!(device.algorithms.len(), 2); + assert_eq!(TrustState::Unset, device.trust_state()); assert_eq!( - device.keys(&KeyAlgorithm::Curve25519).unwrap(), + "Alice's mobile phone", + device.display_name().as_ref().unwrap() + ); + assert_eq!( + device.get_key(&KeyAlgorithm::Curve25519).unwrap(), "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4" ); assert_eq!( - device.keys(&KeyAlgorithm::Ed25519).unwrap(), + device.get_key(&KeyAlgorithm::Ed25519).unwrap(), "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM" ); } diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 703f6836..73f230f6 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -201,7 +201,7 @@ impl OlmMachine { let user_devices = self.store.get_user_devices(user_id).await.unwrap(); for device in user_devices.devices() { - let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) { + let sender_key = if let Some(k) = device.get_key(&KeyAlgorithm::Curve25519) { k } else { continue; @@ -276,7 +276,7 @@ impl OlmMachine { continue; }; - let signing_key = if let Some(k) = device.keys(&KeyAlgorithm::Ed25519) { + let signing_key = if let Some(k) = device.get_key(&KeyAlgorithm::Ed25519) { k } else { warn!( @@ -298,7 +298,7 @@ impl OlmMachine { continue; } - let curve_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) { + let curve_key = if let Some(k) = device.get_key(&KeyAlgorithm::Curve25519) { k } else { warn!( @@ -326,7 +326,7 @@ impl OlmMachine { } }; - if let Err(e) = self.store.add_and_save_session(session).await { + if let Err(e) = self.store.save_session(session).await { error!("Failed to store newly created Olm session {}", e); continue; } @@ -703,7 +703,7 @@ impl OlmMachine { }; let plaintext = session.decrypt(message).await?; - self.store.add_and_save_session(session).await?; + self.store.save_session(session).await?; plaintext }; @@ -865,10 +865,10 @@ impl OlmMachine { let identity_keys = self.account.identity_keys(); let recipient_signing_key = recipient_device - .keys(&KeyAlgorithm::Ed25519) + .get_key(&KeyAlgorithm::Ed25519) .ok_or(OlmError::MissingSigningKey)?; let recipient_sender_key = recipient_device - .keys(&KeyAlgorithm::Curve25519) + .get_key(&KeyAlgorithm::Curve25519) .ok_or(OlmError::MissingSigningKey)?; let payload = json!({ @@ -957,7 +957,7 @@ impl OlmMachine { for user_id in users { for device in self.store.get_user_devices(user_id).await?.devices() { - let sender_key = if let Some(k) = device.keys(&KeyAlgorithm::Curve25519) { + let sender_key = if let Some(k) = device.get_key(&KeyAlgorithm::Curve25519) { k } else { warn!( diff --git a/src/crypto/memory_stores.rs b/src/crypto/memory_stores.rs index b0e39bb6..2510edb9 100644 --- a/src/crypto/memory_stores.rs +++ b/src/crypto/memory_stores.rs @@ -37,7 +37,10 @@ impl SessionStore { } /// Add a session to the store. - pub async fn add(&mut self, session: Session) { + /// + /// Returns true if the the session was added, false if the session was + /// already in the store. + pub async fn add(&mut self, session: Session) -> bool { if !self.entries.contains_key(&*session.sender_key) { self.entries.insert( session.sender_key.to_string(), @@ -45,7 +48,13 @@ impl SessionStore { ); } let sessions = self.entries.get_mut(&*session.sender_key).unwrap(); - sessions.lock().await.push(session); + + if !sessions.lock().await.contains(&session) { + sessions.lock().await.push(session); + true + } else { + false + } } /// Get all the sessions that belong to the given sender key. @@ -75,6 +84,9 @@ impl GroupSessionStore { } /// Add a inbound group session to the store. + /// + /// Returns true if the the session was added, false if the session was + /// already in the store. pub fn add(&mut self, session: InboundGroupSession) -> bool { if !self.entries.contains_key(&session.room_id) { let room_id = &*session.room_id; @@ -91,7 +103,7 @@ impl GroupSessionStore { let sender_map = room_map.get_mut(&*session.sender_key).unwrap(); let ret = sender_map.insert(session.session_id().to_owned(), session); - ret.is_some() + ret.is_none() } /// Get a inbound group session from our store. @@ -163,7 +175,7 @@ impl DeviceStore { device_map .insert(device.device_id().to_owned(), device) - .is_some() + .is_none() } /// Get the device with the given device_id and belonging to the given user. @@ -186,49 +198,22 @@ impl DeviceStore { #[cfg(test)] mod test { - use std::collections::HashMap; use std::convert::TryFrom; - use crate::api::r0::keys::SignedKey; use crate::crypto::device::test::get_device; use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore}; - use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession, Session}; + use crate::crypto::olm::test::get_account_and_session; + use crate::crypto::olm::{InboundGroupSession, OutboundGroupSession}; use crate::identifiers::RoomId; - async fn get_account_and_session() -> (Account, Session) { - let alice = Account::new(); - - let bob = Account::new(); - - bob.generate_one_time_keys(1).await; - let one_time_key = bob - .one_time_keys() - .await - .curve25519() - .iter() - .nth(0) - .unwrap() - .1 - .to_owned(); - let one_time_key = SignedKey { - key: one_time_key, - signatures: HashMap::new(), - }; - let sender_key = bob.identity_keys().curve25519().to_owned(); - let session = alice - .create_outbound_session(&sender_key, &one_time_key) - .await - .unwrap(); - - (alice, session) - } - #[tokio::test] async fn test_session_store() { - let (account, session) = get_account_and_session().await; + let (_, session) = get_account_and_session().await; let mut store = SessionStore::new(); - store.add(session.clone()).await; + + assert!(store.add(session.clone()).await); + assert!(!store.add(session.clone()).await); let sessions = store.get(&session.sender_key).unwrap(); let sessions = sessions.lock().await; @@ -240,7 +225,7 @@ mod test { #[tokio::test] async fn test_session_store_bulk_storing() { - let (account, session) = get_account_and_session().await; + let (_, session) = get_account_and_session().await; let mut store = SessionStore::new(); store.set_for_sender(&session.sender_key, vec![session.clone()]); @@ -255,7 +240,6 @@ mod test { #[tokio::test] async fn test_group_session_store() { - let alice = Account::new(); let room_id = RoomId::try_from("!test:localhost").unwrap(); let outbound = OutboundGroupSession::new(&room_id); @@ -287,8 +271,8 @@ mod test { let device = get_device(); let store = DeviceStore::new(); - assert!(!store.add(device.clone())); assert!(store.add(device.clone())); + assert!(!store.add(device.clone())); let loaded_device = store.get(device.user_id(), device.device_id()).unwrap(); diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 3dff2085..614f5bda 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -613,14 +613,42 @@ impl std::fmt::Debug for OutboundGroupSession { } #[cfg(test)] -mod test { - use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession}; +pub(crate) mod test { + use crate::crypto::olm::{Account, InboundGroupSession, OutboundGroupSession, Session}; use crate::identifiers::RoomId; use olm_rs::session::OlmMessage; use ruma_client_api::r0::keys::SignedKey; use std::collections::HashMap; use std::convert::TryFrom; + pub(crate) async fn get_account_and_session() -> (Account, Session) { + let alice = Account::new(); + + let bob = Account::new(); + + bob.generate_one_time_keys(1).await; + let one_time_key = bob + .one_time_keys() + .await + .curve25519() + .iter() + .nth(0) + .unwrap() + .1 + .to_owned(); + let one_time_key = SignedKey { + key: one_time_key, + signatures: HashMap::new(), + }; + let sender_key = bob.identity_keys().curve25519().to_owned(); + let session = alice + .create_outbound_session(&sender_key, &one_time_key) + .await + .unwrap(); + + (alice, session) + } + #[test] fn account_creation() { let account = Account::new(); @@ -724,7 +752,6 @@ mod test { #[tokio::test] async fn group_session_creation() { - let alice = Account::new(); let room_id = RoomId::try_from("!test:localhost").unwrap(); let outbound = OutboundGroupSession::new(&room_id); diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index 6daa1dba..0e889786 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -21,7 +21,7 @@ use tokio::sync::Mutex; use super::{Account, CryptoStore, InboundGroupSession, Result, Session}; use crate::crypto::device::Device; use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; -use crate::identifiers::{RoomId, UserId}; +use crate::identifiers::{DeviceId, RoomId, UserId}; #[derive(Debug)] pub struct MemoryStore { @@ -52,11 +52,7 @@ impl CryptoStore for MemoryStore { Ok(()) } - async fn save_session(&mut self, _: Session) -> Result<()> { - Ok(()) - } - - async fn add_and_save_session(&mut self, session: Session) -> Result<()> { + async fn save_session(&mut self, session: Session) -> Result<()> { self.sessions.add(session).await; Ok(()) } @@ -88,7 +84,7 @@ impl CryptoStore for MemoryStore { Ok(self.tracked_users.insert(user.clone())) } - async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result> { + async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { Ok(self.devices.get(user_id, device_id)) } @@ -101,3 +97,102 @@ impl CryptoStore for MemoryStore { Ok(()) } } + +#[cfg(test)] +mod test { + use std::convert::TryFrom; + + use crate::crypto::device::test::get_device; + use crate::crypto::olm::test::get_account_and_session; + use crate::crypto::olm::{InboundGroupSession, OutboundGroupSession}; + use crate::crypto::store::memorystore::MemoryStore; + use crate::crypto::store::CryptoStore; + use crate::identifiers::RoomId; + + #[tokio::test] + async fn test_session_store() { + let (account, session) = get_account_and_session().await; + let mut store = MemoryStore::new(); + + assert!(store.load_account().await.unwrap().is_none()); + store.save_account(account).await.unwrap(); + + store.save_session(session.clone()).await.unwrap(); + + let sessions = store + .get_sessions(&session.sender_key) + .await + .unwrap() + .unwrap(); + let sessions = sessions.lock().await; + + let loaded_session = &sessions[0]; + + assert_eq!(&session, loaded_session); + } + + #[tokio::test] + async fn test_group_session_store() { + let room_id = RoomId::try_from("!test:localhost").unwrap(); + + let outbound = OutboundGroupSession::new(&room_id); + let inbound = InboundGroupSession::new( + "test_key", + "test_key", + &room_id, + outbound.session_key().await, + ) + .unwrap(); + + let mut store = MemoryStore::new(); + store + .save_inbound_group_session(inbound.clone()) + .await + .unwrap(); + + let loaded_session = store + .get_inbound_group_session(&room_id, "test_key", outbound.session_id()) + .await + .unwrap() + .unwrap(); + assert_eq!(inbound, loaded_session); + } + + #[tokio::test] + async fn test_device_store() { + let device = get_device(); + let store = MemoryStore::new(); + + store.save_device(device.clone()).await.unwrap(); + + let loaded_device = store + .get_device(device.user_id(), device.device_id()) + .await + .unwrap() + .unwrap(); + + assert_eq!(device, loaded_device); + + let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); + + assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id()); + assert_eq!(user_devices.devices().nth(0).unwrap(), &device); + + let loaded_device = user_devices.get(device.device_id()).unwrap(); + + assert_eq!(device, loaded_device); + } + + #[tokio::test] + async fn test_tracked_users() { + let device = get_device(); + let mut store = MemoryStore::new(); + + assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); + assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap()); + + let tracked_users = store.tracked_users(); + + tracked_users.contains(device.user_id()); + } +} diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index d059b3a7..7181aff1 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -26,7 +26,7 @@ use tokio::sync::Mutex; use super::device::Device; use super::memory_stores::UserDevices; use super::olm::{Account, InboundGroupSession, Session}; -use crate::identifiers::{RoomId, UserId}; +use crate::identifiers::{DeviceId, RoomId, UserId}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; pub mod memorystore; @@ -65,14 +65,48 @@ pub type Result = std::result::Result; #[async_trait] pub trait CryptoStore: Debug + Send + Sync { + /// Load an account that was previously stored. async fn load_account(&mut self) -> Result>; + + /// Save the given account in the store. + /// + /// # Arguments + /// + /// * `account` - The account that should be stored. async fn save_account(&mut self, account: Account) -> Result<()>; + /// Save the given session in the store. + /// + /// # Arguments + /// + /// * `session` - The session that should be stored. async fn save_session(&mut self, session: Session) -> Result<()>; - async fn add_and_save_session(&mut self, session: Session) -> Result<()>; + + /// Get all the sessions that belong to the given sender key. + /// + /// # Arguments + /// + /// * `sender_key` - The sender key that was used to establish the sessions. async fn get_sessions(&mut self, sender_key: &str) -> Result>>>>; + /// Save the given inbound group session in the store. + /// + /// If the session wasn't already in the store true is returned, false + /// otherwise. + /// + /// # Arguments + /// + /// * `session` - The session that should be stored. async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result; + + /// Get the inbound group session from our store. + /// + /// # Arguments + /// * `room_id` - The room id of the room that the session belongs to. + /// + /// * `sender_key` - The sender key that sent us the session. + /// + /// * `session_id` - The unique id of the session. async fn get_inbound_group_session( &mut self, room_id: &RoomId, @@ -80,10 +114,39 @@ pub trait CryptoStore: Debug + Send + Sync { session_id: &str, ) -> Result>; + /// Get the set of tracked users. fn tracked_users(&self) -> &HashSet; + + /// Add an user for tracking. + /// + /// Returns true if the user wasn't already tracked, false otherwise. + /// + /// # Arguments + /// + /// * `user` - The user that should be marked as tracked. async fn add_user_for_tracking(&mut self, user: &UserId) -> Result; + /// Save the given device in the store. + /// + /// # Arguments + /// + /// * `device` - The device that should be stored. async fn save_device(&self, device: Device) -> Result<()>; - async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result>; + + /// Get the device for the given user with the given device id. + /// + /// # Arguments + /// + /// * `user_id` - The user that the device belongs to. + /// + /// * `device_id` - The unique id of the device. + async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result>; + + /// Get all the devices of the given user. + /// + /// + /// # Arguments + /// + /// * `user_id` - The user for which we should get all the devices. async fn get_user_devices(&self, user_id: &UserId) -> Result; } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 99a859f4..237ea913 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; +use std::mem; use std::path::{Path, PathBuf}; use std::result::Result as StdResult; use std::sync::Arc; @@ -28,20 +29,25 @@ use tokio::sync::Mutex; use zeroize::Zeroizing; use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; -use crate::crypto::device::Device; -use crate::crypto::memory_stores::{GroupSessionStore, SessionStore, UserDevices}; -use crate::identifiers::{RoomId, UserId}; +use crate::api::r0::keys::KeyAlgorithm; +use crate::crypto::device::{Device, TrustState}; +use crate::crypto::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; +use crate::events::Algorithm; +use crate::identifiers::{DeviceId, RoomId, UserId}; pub struct SqliteStore { user_id: Arc, device_id: Arc, account_id: Option, path: PathBuf, + sessions: SessionStore, inbound_group_sessions: GroupSessionStore, + devices: DeviceStore, + tracked_users: HashSet, + connection: Arc>, pickle_passphrase: Option>, - tracked_users: HashSet, } static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; @@ -85,6 +91,7 @@ impl SqliteStore { account_id: None, sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), + devices: DeviceStore::new(), path: path.as_ref().to_owned(), connection: Arc::new(Mutex::new(connection)), pickle_passphrase: passphrase, @@ -125,7 +132,7 @@ impl SqliteStore { ON DELETE CASCADE ); - CREATE INDEX "olmsessions_account_id" ON "sessions" ("account_id"); + CREATE INDEX IF NOT EXISTS "olmsessions_account_id" ON "sessions" ("account_id"); "#, ) .await?; @@ -144,7 +151,62 @@ impl SqliteStore { ON DELETE CASCADE ); - CREATE INDEX "olm_groups_sessions_account_id" ON "inbound_group_sessions" ("account_id"); + CREATE INDEX IF NOT EXISTS "olm_groups_sessions_account_id" ON "inbound_group_sessions" ("account_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS devices ( + "id" INTEGER NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "user_id" TEXT NOT NULL, + "device_id" TEXT NOT NULL, + "display_name" TEXT, + "trust_state" INTEGER NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + UNIQUE(account_id,user_id,device_id) + ); + + CREATE INDEX IF NOT EXISTS "devices_account_id" ON "devices" ("account_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS algorithms ( + "id" INTEGER NOT NULL PRIMARY KEY, + "device_id" INTEGER NOT NULL, + "algorithm" TEXT NOT NULL, + FOREIGN KEY ("device_id") REFERENCES "devices" ("id") + ON DELETE CASCADE + UNIQUE(device_id, algorithm) + ); + + CREATE INDEX IF NOT EXISTS "algorithms_device_id" ON "algorithms" ("device_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS device_keys ( + "id" INTEGER NOT NULL PRIMARY KEY, + "device_id" INTEGER NOT NULL, + "algorithm" TEXT NOT NULL, + "key" TEXT NOT NULL, + FOREIGN KEY ("device_id") REFERENCES "devices" ("id") + ON DELETE CASCADE + UNIQUE(device_id, algorithm) + ); + + CREATE INDEX IF NOT EXISTS "device_keys_device_id" ON "device_keys" ("device_id"); "#, ) .await?; @@ -152,10 +214,7 @@ impl SqliteStore { Ok(()) } - async fn get_sessions_for( - &mut self, - sender_key: &str, - ) -> Result>>>> { + async fn lazy_load_sessions(&mut self, sender_key: &str) -> Result<()> { let loaded_sessions = self.sessions.get(sender_key).is_some(); if !loaded_sessions { @@ -166,6 +225,14 @@ impl SqliteStore { } } + Ok(()) + } + + async fn get_sessions_for( + &mut self, + sender_key: &str, + ) -> Result>>>> { + self.lazy_load_sessions(sender_key).await?; Ok(self.sessions.get(sender_key)) } @@ -238,6 +305,142 @@ impl SqliteStore { .collect::>>()?) } + async fn load_devices(&self) -> Result { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let rows: Vec<(i64, String, String, Option, i64)> = query_as( + "SELECT id, user_id, device_id, display_name, trust_state + FROM devices WHERE account_id = ?", + ) + .bind(account_id) + .fetch_all(&mut *connection) + .await?; + + let store = DeviceStore::new(); + + for row in rows { + let device_row_id = row.0; + let user_id = if let Ok(u) = UserId::try_from(&row.1 as &str) { + u + } else { + continue; + }; + + let device_id = &row.2.to_string(); + let display_name = &row.3; + let trust_state = TrustState::from(row.4); + + let algorithm_rows: Vec<(String,)> = + query_as("SELECT algorithm FROM algorithms WHERE device_id = ?") + .bind(device_row_id) + .fetch_all(&mut *connection) + .await?; + + let algorithms = algorithm_rows + .iter() + .map(|row| Algorithm::from(&row.0 as &str)) + .collect::>(); + + let key_rows: Vec<(String, String)> = + query_as("SELECT algorithm, key FROM device_keys WHERE device_id = ?") + .bind(device_row_id) + .fetch_all(&mut *connection) + .await?; + + let mut keys = HashMap::new(); + + for row in key_rows { + let algorithm = if let Ok(a) = KeyAlgorithm::try_from(&row.0 as &str) { + a + } else { + continue; + }; + + let key = &row.1; + + keys.insert(algorithm, key.to_owned()); + } + + let device = Device::new( + user_id, + device_id.to_owned(), + display_name.clone(), + trust_state, + algorithms, + keys, + ); + + store.add(device); + } + + Ok(store) + } + + async fn save_device_helper(&self, device: Device) -> Result<()> { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + + let mut connection = self.connection.lock().await; + + query( + "INSERT INTO devices ( + account_id, user_id, device_id, + display_name, trust_state + ) VALUES (?1, ?2, ?3, ?4, ?5) + ON CONFLICT(account_id, user_id, device_id) DO UPDATE SET + display_name = excluded.display_name, + trust_state = excluded.trust_state + ", + ) + .bind(account_id) + .bind(&device.user_id().to_string()) + .bind(device.device_id()) + .bind(device.display_name()) + .bind(device.trust_state() as i64) + .execute(&mut *connection) + .await?; + + let row: (i64,) = query_as( + "SELECT id FROM devices + WHERE user_id = ? and device_id = ?", + ) + .bind(&device.user_id().to_string()) + .bind(device.device_id()) + .fetch_one(&mut *connection) + .await?; + + let device_row_id = row.0; + + for algorithm in device.algorithms() { + query( + "INSERT OR IGNORE INTO algorithms ( + device_id, algorithm + ) VALUES (?1, ?2) + ", + ) + .bind(device_row_id) + .bind(algorithm.to_string()) + .execute(&mut *connection) + .await?; + } + + for (key_algorithm, key) in device.keys() { + query( + "INSERT OR IGNORE INTO device_keys ( + device_id, algorithm, key + ) VALUES (?1, ?2, ?3) + ", + ) + .bind(device_row_id) + .bind(key_algorithm.to_string()) + .bind(key) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + fn get_pickle_mode(&self) -> PicklingMode { match &self.pickle_passphrase { Some(p) => PicklingMode::Encrypted { @@ -262,29 +465,33 @@ impl CryptoStore for SqliteStore { .fetch_optional(&mut *connection) .await?; - let result = match row { - Some((id, pickle, shared)) => { - self.account_id = Some(id); - Some(Account::from_pickle( - pickle, - self.get_pickle_mode(), - shared, - )?) - } - None => None, + let result = if let Some((id, pickle, shared)) = row { + self.account_id = Some(id); + Some(Account::from_pickle( + pickle, + self.get_pickle_mode(), + shared, + )?) + } else { + return Ok(None); }; drop(connection); - let mut sessions = self.load_inbound_group_sessions().await?; + let mut group_sessions = self.load_inbound_group_sessions().await?; - let _ = sessions + let _ = group_sessions .drain(..) .map(|s| { self.inbound_group_sessions.add(s); }) .collect::<()>(); + let devices = self.load_devices().await?; + mem::replace(&mut self.devices, devices); + + // TODO load the tracked users here as well. + Ok(result) } @@ -297,10 +504,8 @@ impl CryptoStore for SqliteStore { user_id, device_id, pickle, shared ) VALUES (?1, ?2, ?3, ?4) ON CONFLICT(user_id, device_id) DO UPDATE SET - pickle = ?3, - shared = ?4 - WHERE user_id = ?1 and - device_id = ?2 + pickle = excluded.pickle, + shared = excluded.shared ", ) .bind(&*self.user_id.to_string()) @@ -323,6 +528,9 @@ impl CryptoStore for SqliteStore { } async fn save_session(&mut self, session: Session) -> Result<()> { + self.lazy_load_sessions(&session.sender_key).await?; + self.sessions.add(session.clone()).await; + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let session_id = session.session_id(); @@ -349,12 +557,6 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn add_and_save_session(&mut self, session: Session) -> Result<()> { - self.sessions.add(session.clone()).await; - self.save_session(session).await?; - Ok(()) - } - async fn get_sessions(&mut self, sender_key: &str) -> Result>>>> { Ok(self.get_sessions_for(sender_key).await?) } @@ -371,8 +573,7 @@ impl CryptoStore for SqliteStore { room_id, pickle ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) ON CONFLICT(session_id) DO UPDATE SET - pickle = ?6 - WHERE session_id = ?1 + pickle = excluded.pickle ", ) .bind(session_id) @@ -403,37 +604,40 @@ impl CryptoStore for SqliteStore { } async fn add_user_for_tracking(&mut self, user: &UserId) -> Result { + // TODO save the tracked user to the database. Ok(self.tracked_users.insert(user.clone())) } - async fn get_device(&self, _user_id: &UserId, _device_id: &str) -> Result> { - todo!() + async fn save_device(&self, device: Device) -> Result<()> { + self.save_device_helper(device).await } - async fn get_user_devices(&self, _user_id: &UserId) -> Result { - todo!() + async fn get_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { + Ok(self.devices.get(user_id, device_id)) } - async fn save_device(&self, _device: Device) -> Result<()> { - todo!() + async fn get_user_devices(&self, user_id: &UserId) -> Result { + Ok(self.devices.user_devices(user_id)) } } +#[cfg_attr(tarpaulin, skip)] impl std::fmt::Debug for SqliteStore { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> { - write!( - fmt, - "SqliteStore {{ user_id: {}, device_id: {}, path: {:?} }}", - self.user_id, self.device_id, self.path - ) + fmt.debug_struct("SqliteStore") + .field("user_id", &self.user_id) + .field("device_id", &self.device_id) + .field("path", &self.path) + .finish() } } #[cfg(test)] mod test { + use crate::api::r0::keys::SignedKey; + use crate::crypto::device::test::get_device; use crate::crypto::olm::GroupSessionKey; use olm_rs::outbound_group_session::OlmOutboundGroupSession; - use ruma_client_api::r0::keys::SignedKey; use std::collections::HashMap; use tempfile::tempdir; @@ -444,23 +648,39 @@ mod test { static USER_ID: &str = "@example:localhost"; static DEVICE_ID: &str = "DEVICEID"; - async fn get_store() -> SqliteStore { + async fn get_store(passphrase: Option<&str>) -> (SqliteStore, tempfile::TempDir) { let tmpdir = tempdir().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap(); - SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, tmpdir_path) + + let user_id = &UserId::try_from(USER_ID).unwrap(); + + let store = if let Some(passphrase) = passphrase { + SqliteStore::open_with_passphrase( + &user_id, + DEVICE_ID, + tmpdir_path, + passphrase.to_owned(), + ) .await - .expect("Can't create store") + .expect("Can't create a passphrase protected store") + } else { + SqliteStore::open(&user_id, DEVICE_ID, tmpdir_path) + .await + .expect("Can't create store") + }; + + (store, tmpdir) } - async fn get_loaded_store() -> (Account, SqliteStore) { - let mut store = get_store().await; + async fn get_loaded_store() -> (Account, SqliteStore, tempfile::TempDir) { + let (mut store, dir) = get_store(None).await; let account = get_account(); store .save_account(account.clone()) .await .expect("Can't save account"); - (account, store) + (account, store, dir) } fn get_account() -> Account { @@ -506,7 +726,8 @@ mod test { #[tokio::test] async fn save_account() { - let mut store = get_store().await; + let (mut store, _dir) = get_store(None).await; + assert!(store.load_account().await.unwrap().is_none()); let account = get_account(); store @@ -517,7 +738,23 @@ mod test { #[tokio::test] async fn load_account() { - let mut store = get_store().await; + let (mut store, _dir) = get_store(None).await; + let account = get_account(); + + store + .save_account(account.clone()) + .await + .expect("Can't save account"); + + let loaded_account = store.load_account().await.expect("Can't load account"); + let loaded_account = loaded_account.unwrap(); + + assert_eq!(account, loaded_account); + } + + #[tokio::test] + async fn load_account_with_passphrase() { + let (mut store, _dir) = get_store(Some("secret_passphrase")).await; let account = get_account(); store @@ -533,7 +770,7 @@ mod test { #[tokio::test] async fn save_and_share_account() { - let mut store = get_store().await; + let (mut store, _dir) = get_store(None).await; let account = get_account(); store @@ -556,7 +793,7 @@ mod test { #[tokio::test] async fn save_session() { - let mut store = get_store().await; + let (mut store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; assert!(store.save_session(session.clone()).await.is_err()); @@ -571,7 +808,7 @@ mod test { #[tokio::test] async fn load_sessions() { - let mut store = get_store().await; + let (mut store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; store .save_account(account.clone()) @@ -590,7 +827,7 @@ mod test { #[tokio::test] async fn add_and_save_session() { - let mut store = get_store().await; + let (mut store, dir) = get_store(None).await; let (account, session) = get_account_and_session().await; let sender_key = session.sender_key.to_owned(); let session_id = session.session_id().to_owned(); @@ -599,7 +836,23 @@ mod test { .save_account(account.clone()) .await .expect("Can't save account"); - store.add_and_save_session(session).await.unwrap(); + store.save_session(session).await.unwrap(); + + let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); + let sessions_lock = sessions.lock().await; + let session = &sessions_lock[0]; + + assert_eq!(session_id, session.session_id()); + + drop(store); + + let mut store = + SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path()) + .await + .expect("Can't create store"); + + let loaded_account = store.load_account().await.unwrap().unwrap(); + assert_eq!(account, loaded_account); let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); let sessions_lock = sessions.lock().await; @@ -610,7 +863,7 @@ mod test { #[tokio::test] async fn save_inbound_group_session() { - let (account, mut store) = get_loaded_store().await; + let (account, mut store, _dir) = get_loaded_store().await; let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); @@ -630,7 +883,7 @@ mod test { #[tokio::test] async fn load_inbound_group_session() { - let (account, mut store) = get_loaded_store().await; + let (account, mut store, _dir) = get_loaded_store().await; let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); @@ -645,12 +898,67 @@ mod test { let session_id = session.session_id().to_owned(); store - .save_inbound_group_session(session) + .save_inbound_group_session(session.clone()) .await .expect("Can't save group session"); let sessions = store.load_inbound_group_sessions().await.unwrap(); assert_eq!(session_id, sessions[0].session_id()); + + let loaded_session = store + .get_inbound_group_session(&session.room_id, &session.sender_key, session.session_id()) + .await + .unwrap() + .unwrap(); + assert_eq!(session, loaded_session); + } + + #[tokio::test] + async fn test_tracked_users() { + let (_account, mut store, _dir) = get_loaded_store().await; + let device = get_device(); + + assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); + assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap()); + + let tracked_users = store.tracked_users(); + + tracked_users.contains(device.user_id()); + } + + #[tokio::test] + async fn device_saving() { + let (_account, store, dir) = get_loaded_store().await; + let device = get_device(); + + store.save_device(device.clone()).await.unwrap(); + + drop(store); + + let mut store = + SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path()) + .await + .expect("Can't create store"); + + store.load_account().await.unwrap(); + + let loaded_device = store + .get_device(device.user_id(), device.device_id()) + .await + .unwrap() + .unwrap(); + + assert_eq!(device, loaded_device); + + for algorithm in loaded_device.algorithms() { + assert!(device.algorithms().contains(algorithm)); + } + assert_eq!(device.algorithms().len(), loaded_device.algorithms().len()); + assert_eq!(device.keys(), loaded_device.keys()); + + let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); + assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id()); + assert_eq!(user_devices.devices().nth(0).unwrap(), &device); } } diff --git a/src/error.rs b/src/error.rs index 33e119a6..4af53574 100644 --- a/src/error.rs +++ b/src/error.rs @@ -56,7 +56,7 @@ pub enum Error { #[error(transparent)] IoError(#[from] IoError), #[cfg(feature = "encryption")] - /// An error occured durring a E2EE operation. + /// An error occurred during a E2EE operation. #[error(transparent)] OlmError(#[from] OlmError), } diff --git a/src/models/room.rs b/src/models/room.rs index 21ce143b..f6353583 100644 --- a/src/models/room.rs +++ b/src/models/room.rs @@ -18,6 +18,7 @@ use std::convert::TryFrom; use super::RoomMember; +use crate::api::r0::sync::sync_events::RoomSummary; use crate::events::collections::all::{RoomEvent, StateEvent}; use crate::events::presence::PresenceEvent; use crate::events::room::{ @@ -42,6 +43,17 @@ pub struct RoomName { canonical_alias: Option, /// List of `RoomAliasId`s the room has been given. aliases: Vec, + /// Users which can be used to generate a room name if the room does not have + /// one. Required if room name or canonical aliases are not set or empty. + pub heroes: Vec, + /// Number of users whose membership status is `join`. + /// Required if field has changed since last sync; otherwise, it may be + /// omitted. + pub joined_member_count: Option, + /// Number of users whose membership status is `invite`. + /// Required if field has changed since last sync; otherwise, it may be + /// omitted. + pub invited_member_count: Option, } #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -112,11 +124,7 @@ impl RoomName { true } - pub fn calculate_name( - &self, - room_id: &RoomId, - members: &HashMap, - ) -> String { + pub fn calculate_name(&self, members: &HashMap) -> String { // https://matrix.org/docs/spec/client_server/latest#calculating-the-display-name-for-a-room. // the order in which we check for a name ^^ if let Some(name) = &self.name { @@ -126,19 +134,22 @@ impl RoomName { } else if !self.aliases.is_empty() { self.aliases[0].alias().to_string() } else { - let mut names = members - .values() - .flat_map(|m| m.display_name.clone()) - .take(3) - .collect::>(); + let joined = self.joined_member_count.unwrap_or(UInt::max_value()); + let invited = self.invited_member_count.unwrap_or(UInt::max_value()); + let heroes = UInt::new(self.heroes.len() as u64).unwrap(); + let one = UInt::new(1).unwrap(); - if names.is_empty() { - // TODO implement the rest of display name for room spec - format!("Room {}", room_id) - } else { - // stabilize order + if heroes >= (joined + invited - one) { + let mut names = self.heroes.iter().take(3).cloned().collect::>(); names.sort(); names.join(", ") + } else if heroes < (joined + invited - one) && invited + joined > one { + let mut names = self.heroes.iter().take(3).cloned().collect::>(); + names.sort(); + // TODO what is the length the spec wants us to use here and in the `else` + format!("{}, and {} others", names.join(", "), (joined + invited)) + } else { + format!("Empty Room (was {} others)", members.len()) } } } @@ -169,7 +180,7 @@ impl Room { /// Return the display name of the room. pub fn calculate_name(&self) -> String { - self.room_name.calculate_name(&self.room_id, &self.members) + self.room_name.calculate_name(&self.members) } /// Is the room a encrypted room. @@ -239,6 +250,17 @@ impl Room { true } + pub(crate) fn set_room_summary(&mut self, summary: &RoomSummary) { + let RoomSummary { + heroes, + joined_member_count, + invited_member_count, + } = summary; + self.room_name.heroes = heroes.clone(); + self.room_name.invited_member_count = invited_member_count.clone(); + self.room_name.joined_member_count = joined_member_count.clone(); + } + /// Handle a room.member updating the room state if necessary. /// /// Returns true if the joined member list changed, false otherwise.