Merge branch 'crypto-improvements' into new-state-store

master
Damir Jelić 2020-10-24 20:16:59 +02:00
commit 962f725d63
31 changed files with 3388 additions and 1004 deletions

View File

@ -31,7 +31,7 @@ docs = ["encryption", "sqlite_cryptostore", "messages"]
async-trait = "0.1.41" async-trait = "0.1.41"
dashmap = { version = "3.11.10", optional = true } dashmap = { version = "3.11.10", optional = true }
http = "0.2.1" http = "0.2.1"
serde_json = "1.0.58" serde_json = "1.0.59"
thiserror = "1.0.21" thiserror = "1.0.21"
tracing = "0.1.21" tracing = "0.1.21"
url = "2.1.1" url = "2.1.1"
@ -73,7 +73,7 @@ async-std = { version = "1.6.5", features = ["unstable"] }
dirs = "3.0.1" dirs = "3.0.1"
matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" }
tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] }
serde_json = "1.0.58" serde_json = "1.0.59"
tracing-subscriber = "0.2.13" tracing-subscriber = "0.2.13"
tempfile = "3.1.0" tempfile = "3.1.0"
mockito = "0.27.0" mockito = "0.27.0"

View File

@ -0,0 +1,119 @@
use std::{
collections::BTreeMap,
env, io,
process::exit,
sync::atomic::{AtomicBool, Ordering},
};
use serde_json::json;
use url::Url;
use matrix_sdk::{
self,
api::r0::uiaa::AuthData,
identifiers::{user_id, UserId},
Client, ClientConfig, LoopCtrl, SyncSettings,
};
fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> AuthData<'a> {
let mut auth_parameters = BTreeMap::new();
let identifier = json!({
"type": "m.id.user",
"user": user,
});
auth_parameters.insert("identifier".to_owned(), identifier);
auth_parameters.insert("password".to_owned(), password.to_owned().into());
// This is needed because of https://github.com/matrix-org/synapse/issues/5665
auth_parameters.insert("user".to_owned(), user.as_str().into());
AuthData::DirectRequest {
kind: "m.login.password",
auth_parameters,
session,
}
}
async fn bootstrap(client: Client) {
println!("Bootstrapping a new cross signing identity, press enter to continue.");
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.expect("error: unable to read user input");
if let Err(e) = client.bootstrap_cross_signing(None).await {
if let Some(response) = e.uiaa_response() {
let auth_data = auth_data(
&user_id!("@example:localhost"),
"wordpass",
response.session.as_deref(),
);
client
.bootstrap_cross_signing(Some(auth_data))
.await
.expect("Couldn't bootstrap cross signing")
} else {
panic!("Error durign cross signing bootstrap {:#?}", e);
}
}
}
async fn login(
homeserver_url: String,
username: &str,
password: &str,
) -> Result<(), matrix_sdk::Error> {
let client_config = ClientConfig::new()
.disable_ssl_verification()
.proxy("http://localhost:8080")
.unwrap();
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client
.login(username, password, None, Some("rust-sdk"))
.await?;
let client_ref = &client;
let asked = AtomicBool::new(false);
let asked_ref = &asked;
client
.sync_with_callback(SyncSettings::new(), |_| async move {
let asked = asked_ref;
let client = &client_ref;
// Wait for sync to be done then ask the user to bootstrap.
if !asked.load(Ordering::SeqCst) {
tokio::spawn(bootstrap((*client).clone()));
}
asked.store(true, Ordering::SeqCst);
LoopCtrl::Continue
})
.await;
Ok(())
}
#[tokio::main]
async fn main() -> Result<(), matrix_sdk::Error> {
tracing_subscriber::fmt::init();
let (homeserver_url, username, password) =
match (env::args().nth(1), env::args().nth(2), env::args().nth(3)) {
(Some(a), Some(b), Some(c)) => (a, b, c),
_ => {
eprintln!(
"Usage: {} <homeserver_url> <username> <password>",
env::args().next().unwrap()
);
exit(1)
}
};
login(homeserver_url, &username, &password).await
}

View File

@ -65,7 +65,7 @@ pub enum LoopCtrl {
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
account::register, account::register,
device::get_devices, device::{delete_devices, get_devices},
directory::{get_public_rooms, get_public_rooms_filtered}, directory::{get_public_rooms, get_public_rooms_filtered},
media::create_content, media::create_content,
membership::{ membership::{
@ -82,6 +82,7 @@ use matrix_sdk_common::{
typing::create_typing_event::{ typing::create_typing_event::{
Request as TypingRequest, Response as TypingResponse, Typing, Request as TypingRequest, Response as TypingResponse, Typing,
}, },
uiaa::AuthData,
}, },
assign, assign,
events::{ events::{
@ -94,7 +95,7 @@ use matrix_sdk_common::{
}, },
AnyMessageEventContent, AnyMessageEventContent,
}, },
identifiers::{EventId, RoomId, RoomIdOrAliasId, ServerName, UserId}, identifiers::{DeviceIdBox, EventId, RoomId, RoomIdOrAliasId, ServerName, UserId},
instant::{Duration, Instant}, instant::{Duration, Instant},
js_int::UInt, js_int::UInt,
locks::RwLock, locks::RwLock,
@ -106,12 +107,11 @@ use matrix_sdk_common::{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
keys::{get_keys, upload_keys}, keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest},
to_device::send_event_to_device::{ to_device::send_event_to_device::{
Request as RumaToDeviceRequest, Response as ToDeviceResponse, Request as RumaToDeviceRequest, Response as ToDeviceResponse,
}, },
}, },
identifiers::DeviceIdBox,
locks::Mutex, locks::Mutex,
}; };
@ -1369,6 +1369,71 @@ impl Client {
self.send(request).await self.send(request).await
} }
/// Delete the given devices from the server.
///
/// # Arguments
///
/// * `devices` - The list of devices that should be deleted from the
/// server.
///
/// * `auth_data` - This request requires user interactive auth, the first
/// request needs to set this to `None` and will always fail with an
/// `UiaaResponse`. The response will contain information for the
/// interactive auth and the same request needs to be made but this time
/// with some `auth_data` provided.
///
/// ```no_run
/// # use matrix_sdk::{
/// # api::r0::uiaa::{UiaaResponse, AuthData},
/// # Client, SyncSettings, Error, FromHttpResponseError, ServerError,
/// # };
/// # use futures::executor::block_on;
/// # use serde_json::json;
/// # use url::Url;
/// # use std::{collections::BTreeMap, convert::TryFrom};
/// # block_on(async {
/// # let homeserver = Url::parse("http://localhost:8080").unwrap();
/// # let mut client = Client::new(homeserver).unwrap();
/// let devices = &["DEVICEID".into()];
///
/// if let Err(e) = client.delete_devices(devices, None).await {
/// if let Some(info) = e.uiaa_response() {
/// let mut auth_parameters = BTreeMap::new();
///
/// let identifier = json!({
/// "type": "m.id.user",
/// "user": "example",
/// });
/// auth_parameters.insert("identifier".to_owned(), identifier);
/// auth_parameters.insert("password".to_owned(), "wordpass".into());
///
/// // This is needed because of https://github.com/matrix-org/synapse/issues/5665
/// auth_parameters.insert("user".to_owned(), "@example:localhost".into());
///
/// let auth_data = AuthData::DirectRequest {
/// kind: "m.login.password",
/// auth_parameters,
/// session: info.session.as_deref(),
/// };
///
/// client
/// .delete_devices(devices, Some(auth_data))
/// .await
/// .expect("Can't delete devices");
/// }
/// }
/// # });
pub async fn delete_devices(
&self,
devices: &[DeviceIdBox],
auth_data: Option<AuthData<'_>>,
) -> Result<delete_devices::Response> {
let mut request = delete_devices::Request::new(devices);
request.auth = auth_data;
self.send(request).await
}
/// Synchronize the client's state with the latest state on the server. /// Synchronize the client's state with the latest state on the server.
/// ///
/// **Note**: You should not use this method to repeatedly sync if encryption /// **Note**: You should not use this method to repeatedly sync if encryption
@ -1742,6 +1807,33 @@ impl Client {
})) }))
} }
/// TODO
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn bootstrap_cross_signing(&self, auth_data: Option<AuthData<'_>>) -> Result<()> {
let olm = self
.base_client
.olm_machine()
.await
.ok_or(Error::AuthenticationRequired)?;
let (request, signature_request) = olm.bootstrap_cross_signing(false).await?;
println!("HELLOOO MAKING REQUEST {:#?}", request);
let request = UploadSigningKeysRequest {
auth: auth_data,
master_key: request.master_key,
self_signing_key: request.self_signing_key,
user_signing_key: request.user_signing_key,
};
self.send(request).await?;
self.send(signature_request).await?;
Ok(())
}
/// Get a map holding all the devices of an user. /// Get a map holding all the devices of an user.
/// ///
/// This will always return an empty map if the client hasn't been logged /// This will always return an empty map if the client hasn't been logged
@ -1800,6 +1892,8 @@ impl Client {
/// ///
/// # Panics /// # Panics
/// ///
/// This method will panic if it isn't run on a Tokio runtime.
///
/// This method will panic if it can't get enough randomness from the OS to /// This method will panic if it can't get enough randomness from the OS to
/// encrypt the exported keys securely. /// encrypt the exported keys securely.
/// ///
@ -1868,6 +1962,17 @@ impl Client {
/// Import E2EE keys from the given file path. /// Import E2EE keys from the given file path.
/// ///
/// # Arguments
///
/// * `path` - The file path where the exported key file will can be found.
///
/// * `passphrase` - The passphrase that should be used to decrypt the
/// exported room keys.
///
/// # Panics
///
/// This method will panic if it isn't run on a Tokio runtime.
///
/// ```no_run /// ```no_run
/// # use std::{path::PathBuf, time::Duration}; /// # use std::{path::PathBuf, time::Duration};
/// # use matrix_sdk::{ /// # use matrix_sdk::{
@ -1893,7 +1998,7 @@ impl Client {
feature = "docs", feature = "docs",
doc(cfg(all(encryption, not(target_arch = "wasm32")))) doc(cfg(all(encryption, not(target_arch = "wasm32"))))
)] )]
pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result<()> { pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result<usize> {
let olm = self let olm = self
.base_client .base_client
.olm_machine() .olm_machine()
@ -1909,9 +2014,7 @@ impl Client {
let task = tokio::task::spawn_blocking(decrypt); let task = tokio::task::spawn_blocking(decrypt);
let import = task.await.expect("Task join error").unwrap(); let import = task.await.expect("Task join error").unwrap();
olm.import_keys(import).await.unwrap(); Ok(olm.import_keys(import).await?)
Ok(())
} }
} }
@ -1941,7 +2044,10 @@ mod test {
use serde_json::json; use serde_json::json;
use tempfile::tempdir; use tempfile::tempdir;
use std::{convert::TryInto, io::Cursor, path::Path, str::FromStr, time::Duration}; use std::{
collections::BTreeMap, convert::TryInto, io::Cursor, path::Path, str::FromStr,
time::Duration,
};
async fn logged_in_client() -> Client { async fn logged_in_client() -> Client {
let session = Session { let session = Session {
@ -2740,4 +2846,61 @@ mod test {
assert_eq!("tutorial".to_string(), room.read().await.display_name()); assert_eq!("tutorial".to_string(), room.read().await.display_name());
} }
#[tokio::test]
async fn delete_devices() {
let homeserver = Url::from_str(&mockito::server_url()).unwrap();
let client = Client::new(homeserver).unwrap();
let _m = mock("POST", "/_matrix/client/r0/delete_devices")
.with_status(401)
.with_body(
json!({
"flows": [
{
"stages": [
"m.login.password"
]
}
],
"params": {},
"session": "vBslorikviAjxzYBASOBGfPp"
})
.to_string(),
)
.create();
let _m = mock("POST", "/_matrix/client/r0/delete_devices")
.with_status(401)
// empty response
// TODO rename that response type.
.with_body(test_json::LOGOUT.to_string())
.create();
let devices = &["DEVICEID".into()];
if let Err(e) = client.delete_devices(devices, None).await {
if let Some(info) = e.uiaa_response() {
let mut auth_parameters = BTreeMap::new();
let identifier = json!({
"type": "m.id.user",
"user": "example",
});
auth_parameters.insert("identifier".to_owned(), identifier);
auth_parameters.insert("password".to_owned(), "wordpass".into());
let auth_data = AuthData::DirectRequest {
kind: "m.login.password",
auth_parameters,
session: info.session.as_deref(),
};
client
.delete_devices(devices, Some(auth_data))
.await
.unwrap();
}
}
}
} }

View File

@ -19,7 +19,8 @@ use matrix_sdk_base::crypto::{
UserDevices as BaseUserDevices, UserDevices as BaseUserDevices,
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::to_device::send_event_to_device::Request as ToDeviceRequest, identifiers::DeviceId, api::r0::to_device::send_event_to_device::Request as ToDeviceRequest,
identifiers::{DeviceId, DeviceIdBox},
}; };
use crate::{error::Result, http_client::HttpClient, Sas}; use crate::{error::Result, http_client::HttpClient, Sas};
@ -114,7 +115,7 @@ impl UserDevices {
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> { pub fn keys(&self) -> impl Iterator<Item = &DeviceIdBox> {
self.inner.keys() self.inner.keys()
} }

View File

@ -16,8 +16,11 @@
use matrix_sdk_base::Error as MatrixError; use matrix_sdk_base::Error as MatrixError;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::{r0::uiaa::UiaaResponse as UiaaError, Error as RumaClientError}, api::{
FromHttpResponseError as RumaResponseError, IntoHttpError as RumaIntoHttpError, r0::uiaa::{UiaaInfo, UiaaResponse as UiaaError},
Error as RumaClientError,
},
FromHttpResponseError as RumaResponseError, IntoHttpError as RumaIntoHttpError, ServerError,
}; };
use reqwest::Error as ReqwestError; use reqwest::Error as ReqwestError;
use serde_json::Error as JsonError; use serde_json::Error as JsonError;
@ -79,6 +82,30 @@ pub enum Error {
UiaaError(RumaResponseError<UiaaError>), UiaaError(RumaResponseError<UiaaError>),
} }
impl Error {
/// Try to destructure the error into an universal interactive auth info.
///
/// Some requests require universal interactive auth, doing such a request
/// will always fail the first time with a 401 status code, the response
/// body will contain info how the client can authenticate.
///
/// The request will need to be retried, this time containing additional
/// authentication data.
///
/// This method is an convenience method to get to the info the server
/// returned on the first, failed request.
pub fn uiaa_response(&self) -> Option<&UiaaInfo> {
if let Error::UiaaError(RumaResponseError::Http(ServerError::Known(
UiaaError::AuthResponse(i),
))) = self
{
Some(i)
} else {
None
}
}
}
impl From<RumaResponseError<UiaaError>> for Error { impl From<RumaResponseError<UiaaError>> for Error {
fn from(error: RumaResponseError<UiaaError>) -> Self { fn from(error: RumaResponseError<UiaaError>) -> Self {
Self::UiaaError(error) Self::UiaaError(error)

View File

@ -27,7 +27,7 @@ docs = ["encryption", "sqlite_cryptostore", "messages"]
async-trait = "0.1.41" async-trait = "0.1.41"
serde = "1.0.116" serde = "1.0.116"
dashmap= "*" dashmap= "*"
serde_json = "1.0.58" serde_json = "1.0.59"
zeroize = "1.1.1" zeroize = "1.1.1"
tracing = "0.1.21" tracing = "0.1.21"

View File

@ -20,7 +20,7 @@ js_int = "0.1.9"
[dependencies.ruma] [dependencies.ruma]
version = "0.0.1" version = "0.0.1"
git = "https://github.com/ruma/ruma" path = "/home/poljar/werk/priv/ruma/ruma"
rev = "50eb700571480d1440e15a387d10f98be8abab59" rev = "50eb700571480d1440e15a387d10f98be8abab59"
features = ["client-api", "unstable-pre-spec", "unstable-exhaustive-types"] features = ["client-api", "unstable-pre-spec", "unstable-exhaustive-types"]

View File

@ -28,7 +28,7 @@ matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
olm-rs = { version = "1.0.0", features = ["serde"] } olm-rs = { version = "1.0.0", features = ["serde"] }
getrandom = "0.2.0" getrandom = "0.2.0"
serde = { version = "1.0.116", features = ["derive", "rc"] } serde = { version = "1.0.116", features = ["derive", "rc"] }
serde_json = "1.0.58" serde_json = "1.0.59"
cjson = "0.1.1" cjson = "0.1.1"
zeroize = { version = "1.1.1", features = ["zeroize_derive"] } zeroize = { version = "1.1.1", features = ["zeroize_derive"] }
url = "2.1.1" url = "2.1.1"
@ -39,6 +39,7 @@ tracing = "0.1.21"
atomic = "0.5.0" atomic = "0.5.0"
dashmap = "3.11.10" dashmap = "3.11.10"
sha2 = "0.9.1" sha2 = "0.9.1"
aes-gcm = "0.7.0"
aes-ctr = "0.5.0" aes-ctr = "0.5.0"
pbkdf2 = { version = "0.5.0", default-features = false } pbkdf2 = { version = "0.5.0", default-features = false }
hmac = "0.9.0" hmac = "0.9.0"
@ -51,7 +52,8 @@ default-features = false
features = ["std", "std-future"] features = ["std", "std-future"]
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx] [target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx]
version = "0.3.5" git = "https://github.com/launchbadge/sqlx/"
rev = "fd25a7530cf087e1529553ff854f192738db3461"
optional = true optional = true
default-features = false default-features = false
features = ["runtime-tokio", "sqlite", "macros"] features = ["runtime-tokio", "sqlite", "macros"]
@ -60,7 +62,7 @@ features = ["runtime-tokio", "sqlite", "macros"]
tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] }
futures = "0.3.6" futures = "0.3.6"
proptest = "0.10.1" proptest = "0.10.1"
serde_json = "1.0.58" serde_json = "1.0.59"
tempfile = "3.1.0" tempfile = "3.1.0"
http = "0.2.1" http = "0.2.1"
matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" }

View File

@ -48,8 +48,8 @@ pub enum OlmError {
Store(#[from] CryptoStoreError), Store(#[from] CryptoStoreError),
/// The session with a device has become corrupted. /// The session with a device has become corrupted.
#[error("decryption failed likely because a Olm session was wedged")] #[error("decryption failed likely because an Olm from {0} with sender key {1} was wedged")]
SessionWedged, SessionWedged(UserId, String),
/// Encryption failed because the device does not have a valid Olm session /// Encryption failed because the device does not have a valid Olm session
/// with us. /// with us.
@ -148,6 +148,9 @@ pub enum SignatureError {
#[error("the provided JSON object can't be converted to a canonical representation")] #[error("the provided JSON object can't be converted to a canonical representation")]
CanonicalJsonError(CjsonError), CanonicalJsonError(CjsonError),
#[error(transparent)]
JsonError(#[from] SerdeError),
#[error("the signature didn't match the provided key")] #[error("the signature didn't match the provided key")]
VerificationError, VerificationError,
} }

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use std::{ use std::{
collections::BTreeMap, collections::{BTreeMap, HashMap},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
ops::Deref, ops::Deref,
sync::{ sync::{
@ -30,13 +30,19 @@ use matrix_sdk_common::{
forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent, forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent,
EventType, EventType,
}, },
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId}, identifiers::{
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId,
},
locks::Mutex,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::warn; use tracing::warn;
use crate::olm::InboundGroupSession; use crate::{
olm::{InboundGroupSession, Session},
store::{Changes, DeviceChanges},
};
#[cfg(test)] #[cfg(test)]
use crate::{OlmMachine, ReadOnlyAccount}; use crate::{OlmMachine, ReadOnlyAccount};
@ -44,7 +50,7 @@ use crate::{
error::{EventError, OlmError, OlmResult, SignatureError}, error::{EventError, OlmError, OlmResult, SignatureError},
identities::{OwnUserIdentity, UserIdentities}, identities::{OwnUserIdentity, UserIdentities},
olm::Utility, olm::Utility,
store::{caches::ReadOnlyUserDevices, CryptoStore, Result as StoreResult}, store::{CryptoStore, Result as StoreResult},
verification::VerificationMachine, verification::VerificationMachine,
Sas, ToDeviceRequest, Sas, ToDeviceRequest,
}; };
@ -89,6 +95,15 @@ impl Device {
.await .await
} }
/// Get the Olm sessions that belong to this device.
pub(crate) async fn get_sessions(&self) -> StoreResult<Option<Arc<Mutex<Vec<Session>>>>> {
if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) {
self.verification_machine.store.get_sessions(k).await
} else {
Ok(None)
}
}
/// Get the trust state of the device. /// Get the trust state of the device.
pub fn trust_state(&self) -> bool { pub fn trust_state(&self) -> bool {
self.inner self.inner
@ -106,10 +121,15 @@ impl Device {
pub async fn set_local_trust(&self, trust_state: LocalTrust) -> StoreResult<()> { pub async fn set_local_trust(&self, trust_state: LocalTrust) -> StoreResult<()> {
self.inner.set_trust_state(trust_state); self.inner.set_trust_state(trust_state);
self.verification_machine let changes = Changes {
.store devices: DeviceChanges {
.save_devices(&[self.inner.clone()]) changed: vec![self.inner.clone()],
.await ..Default::default()
},
..Default::default()
};
self.verification_machine.store.save_changes(changes).await
} }
/// Encrypt the given content for this `Device`. /// Encrypt the given content for this `Device`.
@ -123,7 +143,7 @@ impl Device {
&self, &self,
event_type: EventType, event_type: EventType,
content: Value, content: Value,
) -> OlmResult<EncryptedEventContent> { ) -> OlmResult<(Session, EncryptedEventContent)> {
self.inner self.inner
.encrypt(&**self.verification_machine.store, event_type, content) .encrypt(&**self.verification_machine.store, event_type, content)
.await .await
@ -134,7 +154,7 @@ impl Device {
pub async fn encrypt_session( pub async fn encrypt_session(
&self, &self,
session: InboundGroupSession, session: InboundGroupSession,
) -> OlmResult<EncryptedEventContent> { ) -> OlmResult<(Session, EncryptedEventContent)> {
let export = session.export().await; let export = session.export().await;
let content: ForwardedRoomKeyEventContent = if let Ok(c) = export.try_into() { let content: ForwardedRoomKeyEventContent = if let Ok(c) = export.try_into() {
@ -158,7 +178,7 @@ impl Device {
/// A read only view over all devices belonging to a user. /// A read only view over all devices belonging to a user.
#[derive(Debug)] #[derive(Debug)]
pub struct UserDevices { pub struct UserDevices {
pub(crate) inner: ReadOnlyUserDevices, pub(crate) inner: HashMap<DeviceIdBox, ReadOnlyDevice>,
pub(crate) verification_machine: VerificationMachine, pub(crate) verification_machine: VerificationMachine,
pub(crate) own_identity: Option<OwnUserIdentity>, pub(crate) own_identity: Option<OwnUserIdentity>,
pub(crate) device_owner_identity: Option<UserIdentities>, pub(crate) device_owner_identity: Option<UserIdentities>,
@ -168,7 +188,7 @@ impl UserDevices {
/// Get the specific device with the given device id. /// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<Device> { pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
self.inner.get(device_id).map(|d| Device { self.inner.get(device_id).map(|d| Device {
inner: d, inner: d.clone(),
verification_machine: self.verification_machine.clone(), verification_machine: self.verification_machine.clone(),
own_identity: self.own_identity.clone(), own_identity: self.own_identity.clone(),
device_owner_identity: self.device_owner_identity.clone(), device_owner_identity: self.device_owner_identity.clone(),
@ -176,13 +196,13 @@ impl UserDevices {
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> { pub fn keys(&self) -> impl Iterator<Item = &DeviceIdBox> {
self.inner.keys() self.inner.keys()
} }
/// Iterator over all the devices of the user devices. /// Iterator over all the devices of the user devices.
pub fn devices(&self) -> impl Iterator<Item = Device> + '_ { pub fn devices(&self) -> impl Iterator<Item = Device> + '_ {
self.inner.devices().map(move |d| Device { self.inner.values().map(move |d| Device {
inner: d.clone(), inner: d.clone(),
verification_machine: self.verification_machine.clone(), verification_machine: self.verification_machine.clone(),
own_identity: self.own_identity.clone(), own_identity: self.own_identity.clone(),
@ -352,7 +372,7 @@ impl ReadOnlyDevice {
store: &dyn CryptoStore, store: &dyn CryptoStore,
event_type: EventType, event_type: EventType,
content: Value, content: Value,
) -> OlmResult<EncryptedEventContent> { ) -> OlmResult<(Session, EncryptedEventContent)> {
let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) { let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) {
k k
} else { } else {
@ -384,10 +404,9 @@ impl ReadOnlyDevice {
return Err(OlmError::MissingSession); return Err(OlmError::MissingSession);
}; };
let message = session.encrypt(&self, event_type, content).await; let message = session.encrypt(&self, event_type, content).await?;
store.save_sessions(&[session]).await?;
message Ok((session, message))
} }
/// Update a device with a new device keys struct. /// Update a device with a new device keys struct.

View File

@ -27,13 +27,13 @@ use matrix_sdk_common::{
use crate::{ use crate::{
error::OlmResult, error::OlmResult,
group_manager::GroupSessionManager,
identities::{ identities::{
MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserIdentities, MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserIdentities,
UserIdentity, UserSigningPubkey, UserIdentity, UserSigningPubkey,
}, },
requests::KeysQueryRequest, requests::KeysQueryRequest,
store::{Result as StoreResult, Store}, session_manager::GroupSessionManager,
store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -79,7 +79,7 @@ impl IdentityManager {
pub async fn receive_keys_query_response( pub async fn receive_keys_query_response(
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
// TODO create a enum that tells us how the device/identity changed, // TODO create a enum that tells us how the device/identity changed,
// e.g. new/deleted/display name change. // e.g. new/deleted/display name change.
// //
@ -92,9 +92,15 @@ impl IdentityManager {
let changed_devices = self let changed_devices = self
.handle_devices_from_key_query(&response.device_keys) .handle_devices_from_key_query(&response.device_keys)
.await?; .await?;
self.store.save_devices(&changed_devices).await?;
let changed_identities = self.handle_cross_singing_keys(response).await?; let changed_identities = self.handle_cross_singing_keys(response).await?;
self.store.save_user_identities(&changed_identities).await?;
let changes = Changes {
identities: changed_identities.clone(),
devices: changed_devices.clone(),
..Default::default()
};
self.store.save_changes(changes).await?;
Ok((changed_devices, changed_identities)) Ok((changed_devices, changed_identities))
} }
@ -111,9 +117,10 @@ impl IdentityManager {
async fn handle_devices_from_key_query( async fn handle_devices_from_key_query(
&self, &self,
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>, device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
) -> StoreResult<Vec<ReadOnlyDevice>> { ) -> StoreResult<DeviceChanges> {
let mut users_with_new_or_deleted_devices = HashSet::new(); let mut users_with_new_or_deleted_devices = HashSet::new();
let mut changed_devices = Vec::new();
let mut changes = DeviceChanges::default();
for (user_id, device_map) in device_keys_map { for (user_id, device_map) in device_keys_map {
// TODO move this out into the handle keys query response method // TODO move this out into the handle keys query response method
@ -137,7 +144,7 @@ impl IdentityManager {
let device = self.store.get_readonly_device(&user_id, device_id).await?; let device = self.store.get_readonly_device(&user_id, device_id).await?;
let device = if let Some(mut device) = device { if let Some(mut device) = device {
if let Err(e) = device.update_device(device_keys) { if let Err(e) = device.update_device(device_keys) {
warn!( warn!(
"Failed to update the device keys for {} {}: {:?}", "Failed to update the device keys for {} {}: {:?}",
@ -145,7 +152,7 @@ impl IdentityManager {
); );
continue; continue;
} }
device changes.changed.push(device);
} else { } else {
let device = match ReadOnlyDevice::try_from(device_keys) { let device = match ReadOnlyDevice::try_from(device_keys) {
Ok(d) => d, Ok(d) => d,
@ -159,24 +166,21 @@ impl IdentityManager {
}; };
info!("Adding a new device to the device store {:?}", device); info!("Adding a new device to the device store {:?}", device);
users_with_new_or_deleted_devices.insert(user_id); users_with_new_or_deleted_devices.insert(user_id);
device changes.new.push(device);
}; }
changed_devices.push(device);
} }
let current_devices: HashSet<&DeviceId> = let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
device_map.keys().map(|id| id.as_ref()).collect();
let stored_devices = self.store.get_readonly_devices(&user_id).await?; let stored_devices = self.store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices = stored_devices_set.difference(&current_devices); let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices { for device_id in deleted_devices_set {
users_with_new_or_deleted_devices.insert(user_id); users_with_new_or_deleted_devices.insert(user_id);
if let Some(device) = stored_devices.get(device_id) { if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted(); device.mark_as_deleted();
self.store.delete_device(device).await?; changes.deleted.push(device.clone());
} }
} }
} }
@ -184,7 +188,7 @@ impl IdentityManager {
self.group_manager self.group_manager
.invalidate_sessions_new_devices(&users_with_new_or_deleted_devices); .invalidate_sessions_new_devices(&users_with_new_or_deleted_devices);
Ok(changed_devices) Ok(changes)
} }
/// Handle the device keys part of a key query response. /// Handle the device keys part of a key query response.
@ -198,8 +202,8 @@ impl IdentityManager {
async fn handle_cross_singing_keys( async fn handle_cross_singing_keys(
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> StoreResult<Vec<UserIdentities>> { ) -> StoreResult<IdentityChanges> {
let mut changed = Vec::new(); let mut changes = IdentityChanges::default();
for (user_id, master_key) in &response.master_keys { for (user_id, master_key) in &response.master_keys {
let master_key = MasterPubkey::from(master_key); let master_key = MasterPubkey::from(master_key);
@ -214,7 +218,7 @@ impl IdentityManager {
continue; continue;
}; };
let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? { let result = if let Some(mut i) = self.store.get_user_identity(user_id).await? {
match &mut i { match &mut i {
UserIdentities::Own(ref mut identity) => { UserIdentities::Own(ref mut identity) => {
let user_signing = if let Some(s) = response.user_signing_keys.get(user_id) let user_signing = if let Some(s) = response.user_signing_keys.get(user_id)
@ -231,11 +235,11 @@ impl IdentityManager {
identity identity
.update(master_key, self_signing, user_signing) .update(master_key, self_signing, user_signing)
.map(|_| i) .map(|_| (i, false))
}
UserIdentities::Other(ref mut identity) => {
identity.update(master_key, self_signing).map(|_| i)
} }
UserIdentities::Other(ref mut identity) => identity
.update(master_key, self_signing)
.map(|_| (i, false)),
} }
} else if user_id == self.user_id() { } else if user_id == self.user_id() {
if let Some(s) = response.user_signing_keys.get(user_id) { if let Some(s) = response.user_signing_keys.get(user_id) {
@ -253,7 +257,7 @@ impl IdentityManager {
} }
OwnUserIdentity::new(master_key, self_signing, user_signing) OwnUserIdentity::new(master_key, self_signing, user_signing)
.map(UserIdentities::Own) .map(|i| (UserIdentities::Own(i), true))
} else { } else {
warn!( warn!(
"User identity for our own user {} didn't contain a \ "User identity for our own user {} didn't contain a \
@ -269,17 +273,22 @@ impl IdentityManager {
); );
continue; continue;
} else { } else {
UserIdentity::new(master_key, self_signing).map(UserIdentities::Other) UserIdentity::new(master_key, self_signing)
.map(|i| (UserIdentities::Other(i), true))
}; };
match identity { match result {
Ok(i) => { Ok((i, new)) => {
trace!( trace!(
"Updated or created new user identity for {}: {:?}", "Updated or created new user identity for {}: {:?}",
user_id, user_id,
i i
); );
changed.push(i); if new {
changes.new.push(i);
} else {
changes.changed.push(i);
}
} }
Err(e) => { Err(e) => {
warn!( warn!(
@ -291,7 +300,7 @@ impl IdentityManager {
} }
} }
Ok(changed) Ok(changes)
} }
/// Get a key query request if one is needed. /// Get a key query request if one is needed.
@ -369,6 +378,7 @@ pub(crate) mod test {
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeyQueryResponse, api::r0::keys::get_keys::Response as KeyQueryResponse,
identifiers::{room_id, user_id, DeviceIdBox, RoomId, UserId}, identifiers::{room_id, user_id, DeviceIdBox, RoomId, UserId},
locks::Mutex,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
@ -376,10 +386,10 @@ pub(crate) mod test {
use serde_json::json; use serde_json::json;
use crate::{ use crate::{
group_manager::GroupSessionManager,
identities::IdentityManager, identities::IdentityManager,
machine::test::response_from_file, machine::test::response_from_file,
olm::{Account, ReadOnlyAccount}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
session_manager::GroupSessionManager,
store::{CryptoStore, MemoryStore, Store}, store::{CryptoStore, MemoryStore, Store},
verification::VerificationMachine, verification::VerificationMachine,
}; };
@ -401,10 +411,11 @@ pub(crate) mod test {
} }
fn manager() -> IdentityManager { fn manager() -> IdentityManager {
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id())));
let user_id = Arc::new(user_id()); let user_id = Arc::new(user_id());
let account = ReadOnlyAccount::new(&user_id, &device_id()); let account = ReadOnlyAccount::new(&user_id, &device_id());
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
let verification = VerificationMachine::new(account.clone(), store); let verification = VerificationMachine::new(account.clone(), identity, store);
let store = Store::new( let store = Store::new(
user_id.clone(), user_id.clone(),
Arc::new(Box::new(MemoryStore::new())), Arc::new(Box::new(MemoryStore::new())),

View File

@ -86,6 +86,24 @@ impl From<CrossSigningKey> for UserSigningPubkey {
} }
} }
impl Into<CrossSigningKey> for MasterPubkey {
fn into(self) -> CrossSigningKey {
self.0.as_ref().clone()
}
}
impl Into<CrossSigningKey> for UserSigningPubkey {
fn into(self) -> CrossSigningKey {
self.0.as_ref().clone()
}
}
impl Into<CrossSigningKey> for SelfSigningPubkey {
fn into(self) -> CrossSigningKey {
self.0.as_ref().clone()
}
}
impl AsRef<CrossSigningKey> for MasterPubkey { impl AsRef<CrossSigningKey> for MasterPubkey {
fn as_ref(&self) -> &CrossSigningKey { fn as_ref(&self) -> &CrossSigningKey {
&self.0 &self.0
@ -135,7 +153,7 @@ impl<'a> From<&'a UserSigningPubkey> for CrossSigningSubKeys<'a> {
} }
/// Enum over the cross signing sub-keys. /// Enum over the cross signing sub-keys.
enum CrossSigningSubKeys<'a> { pub(crate) enum CrossSigningSubKeys<'a> {
/// The self signing subkey. /// The self signing subkey.
SelfSigning(&'a SelfSigningPubkey), SelfSigning(&'a SelfSigningPubkey),
/// The user signing subkey. /// The user signing subkey.
@ -152,7 +170,7 @@ impl<'a> CrossSigningSubKeys<'a> {
} }
/// Get the `CrossSigningKey` from an sub-keys enum /// Get the `CrossSigningKey` from an sub-keys enum
fn cross_signing_key(&self) -> &CrossSigningKey { pub(crate) fn cross_signing_key(&self) -> &CrossSigningKey {
match self { match self {
CrossSigningSubKeys::SelfSigning(key) => &key.0, CrossSigningSubKeys::SelfSigning(key) => &key.0,
CrossSigningSubKeys::UserSigning(key) => &key.0, CrossSigningSubKeys::UserSigning(key) => &key.0,
@ -198,7 +216,7 @@ impl MasterPubkey {
/// ///
/// Returns an empty result if the signature check succeeded, otherwise a /// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed. /// SignatureError indicating why the check failed.
fn verify_subkey<'a>( pub(crate) fn verify_subkey<'a>(
&self, &self,
subkey: impl Into<CrossSigningSubKeys<'a>>, subkey: impl Into<CrossSigningSubKeys<'a>>,
) -> Result<(), SignatureError> { ) -> Result<(), SignatureError> {
@ -666,13 +684,13 @@ pub(crate) mod test {
manager::test::{other_key_query, own_key_query}, manager::test::{other_key_query, own_key_query},
Device, ReadOnlyDevice, Device, ReadOnlyDevice,
}, },
olm::ReadOnlyAccount, olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
store::MemoryStore, store::MemoryStore,
verification::VerificationMachine, verification::VerificationMachine,
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, locks::Mutex,
}; };
use super::{OwnUserIdentity, UserIdentities, UserIdentity}; use super::{OwnUserIdentity, UserIdentities, UserIdentity};
@ -734,8 +752,12 @@ pub(crate) mod test {
assert!(identity.is_device_signed(&first).is_err()); assert!(identity.is_device_signed(&first).is_err());
assert!(identity.is_device_signed(&second).is_ok()); assert!(identity.is_device_signed(&second).is_ok());
let private_identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(
second.user_id().clone(),
)));
let verification_machine = VerificationMachine::new( let verification_machine = VerificationMachine::new(
ReadOnlyAccount::new(second.user_id(), second.device_id()), ReadOnlyAccount::new(second.user_id(), second.device_id()),
private_identity,
Arc::new(Box::new(MemoryStore::new())), Arc::new(Box::new(MemoryStore::new())),
); );

View File

@ -41,7 +41,7 @@ use matrix_sdk_common::{
use crate::{ use crate::{
error::{OlmError, OlmResult}, error::{OlmError, OlmResult},
olm::{InboundGroupSession, OutboundGroupSession}, olm::{InboundGroupSession, OutboundGroupSession, Session},
requests::{OutgoingRequest, ToDeviceRequest}, requests::{OutgoingRequest, ToDeviceRequest},
store::{CryptoStoreError, Store}, store::{CryptoStoreError, Store},
Device, Device,
@ -196,6 +196,7 @@ impl KeyRequestMachine {
device_id: Arc<DeviceIdBox>, device_id: Arc<DeviceIdBox>,
store: Store, store: Store,
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>, outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
) -> Self { ) -> Self {
Self { Self {
user_id, user_id,
@ -205,7 +206,7 @@ impl KeyRequestMachine {
outgoing_to_device_requests: Arc::new(DashMap::new()), outgoing_to_device_requests: Arc::new(DashMap::new()),
incoming_key_requests: Arc::new(DashMap::new()), incoming_key_requests: Arc::new(DashMap::new()),
wait_queue: WaitQueue::new(), wait_queue: WaitQueue::new(),
users_for_key_claim: Arc::new(DashMap::new()), users_for_key_claim,
} }
} }
@ -214,11 +215,6 @@ impl KeyRequestMachine {
&self.user_id &self.user_id
} }
/// Get the map of user/devices which we need to claim one-time for.
pub fn users_for_key_claim(&self) -> &DashMap<UserId, DashSet<DeviceIdBox>> {
&self.users_for_key_claim
}
pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> { pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
self.outgoing_to_device_requests self.outgoing_to_device_requests
@ -239,15 +235,18 @@ impl KeyRequestMachine {
/// Handle all the incoming key requests that are queued up and empty our /// Handle all the incoming key requests that are queued up and empty our
/// key request queue. /// key request queue.
pub async fn collect_incoming_key_requests(&self) -> OlmResult<()> { pub async fn collect_incoming_key_requests(&self) -> OlmResult<Vec<Session>> {
let mut changed_sessions = Vec::new();
for item in self.incoming_key_requests.iter() { for item in self.incoming_key_requests.iter() {
let event = item.value(); let event = item.value();
self.handle_key_request(event).await?; if let Some(s) = self.handle_key_request(event).await? {
changed_sessions.push(s);
}
} }
self.incoming_key_requests.clear(); self.incoming_key_requests.clear();
Ok(()) Ok(changed_sessions)
} }
/// Store the key share request for later, once we get an Olm session with /// Store the key share request for later, once we get an Olm session with
@ -298,7 +297,7 @@ impl KeyRequestMachine {
async fn handle_key_request( async fn handle_key_request(
&self, &self,
event: &ToDeviceEvent<RoomKeyRequestEventContent>, event: &ToDeviceEvent<RoomKeyRequestEventContent>,
) -> OlmResult<()> { ) -> OlmResult<Option<Session>> {
let key_info = match event.content.action { let key_info = match event.content.action {
Action::Request => { Action::Request => {
if let Some(info) = &event.content.body { if let Some(info) = &event.content.body {
@ -309,11 +308,11 @@ impl KeyRequestMachine {
action, but no key info was found", action, but no key info was found",
event.sender, event.content.requesting_device_id event.sender, event.content.requesting_device_id
); );
return Ok(()); return Ok(None);
} }
} }
// We ignore cancellations here since there's nothing to serve. // We ignore cancellations here since there's nothing to serve.
Action::CancelRequest => return Ok(()), Action::CancelRequest => return Ok(None),
}; };
let session = self let session = self
@ -332,7 +331,7 @@ impl KeyRequestMachine {
"Received a key request from {} {} for an unknown inbound group session {}.", "Received a key request from {} {} for an unknown inbound group session {}.",
&event.sender, &event.content.requesting_device_id, &key_info.session_id &event.sender, &event.content.requesting_device_id, &key_info.session_id
); );
return Ok(()); return Ok(None);
}; };
let device = self let device = self
@ -353,6 +352,8 @@ impl KeyRequestMachine {
device.device_id(), device.device_id(),
e e
); );
Ok(None)
} else { } else {
info!( info!(
"Serving a key request for {} from {} {}.", "Serving a key request for {} from {} {}.",
@ -361,9 +362,9 @@ impl KeyRequestMachine {
device.device_id() device.device_id()
); );
if let Err(e) = self.share_session(&session, &device).await { match self.share_session(&session, &device).await {
match e { Ok(s) => Ok(Some(s)),
OlmError::MissingSession => { Err(OlmError::MissingSession) => {
info!( info!(
"Key request from {} {} is missing an Olm session, \ "Key request from {} {} is missing an Olm session, \
putting the request in the wait queue", putting the request in the wait queue",
@ -371,10 +372,10 @@ impl KeyRequestMachine {
device.device_id() device.device_id()
); );
self.handle_key_share_without_session(device, event); self.handle_key_share_without_session(device, event);
return Ok(());
} Ok(None)
e => return Err(e),
} }
Err(e) => Err(e),
} }
} }
} else { } else {
@ -383,13 +384,17 @@ impl KeyRequestMachine {
&event.sender, &event.content.requesting_device_id &event.sender, &event.content.requesting_device_id
); );
self.store.update_tracked_user(&event.sender, true).await?; self.store.update_tracked_user(&event.sender, true).await?;
Ok(None)
}
} }
Ok(()) async fn share_session(
} &self,
session: &InboundGroupSession,
async fn share_session(&self, session: &InboundGroupSession, device: &Device) -> OlmResult<()> { device: &Device,
let content = device.encrypt_session(session.clone()).await?; ) -> OlmResult<Session> {
let (used_session, content) = device.encrypt_session(session.clone()).await?;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
@ -416,7 +421,7 @@ impl KeyRequestMachine {
self.outgoing_to_device_requests.insert(id, request); self.outgoing_to_device_requests.insert(id, request);
Ok(()) Ok(used_session)
} }
/// Check if it's ok to share a session with the given device. /// Check if it's ok to share a session with the given device.
@ -573,23 +578,20 @@ impl KeyRequestMachine {
Ok(()) Ok(())
} }
/// Save an inbound group session we received using a key forward. /// Mark the given outgoing key info as done.
/// ///
/// At the same time delete the key info since we received the wanted key. /// This will queue up a request cancelation.
async fn save_session( async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> {
&self,
key_info: OugoingKeyInfo,
session: InboundGroupSession,
) -> Result<(), CryptoStoreError> {
// TODO perhaps only remove the key info if the first known index is 0. // TODO perhaps only remove the key info if the first known index is 0.
trace!( trace!(
"Successfully received a forwarded room key for {:#?}", "Successfully received a forwarded room key for {:#?}",
key_info key_info
); );
self.store.save_inbound_group_sessions(&[session]).await?;
self.outgoing_to_device_requests self.outgoing_to_device_requests
.remove(&key_info.request_id); .remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction.
self.delete_key_info(&key_info).await?; self.delete_key_info(&key_info).await?;
let content = RoomKeyRequestEventContent { let content = RoomKeyRequestEventContent {
@ -613,7 +615,8 @@ impl KeyRequestMachine {
&self, &self,
sender_key: &str, sender_key: &str,
event: &mut ToDeviceEvent<ForwardedRoomKeyEventContent>, event: &mut ToDeviceEvent<ForwardedRoomKeyEventContent>,
) -> Result<Option<Raw<AnyToDeviceEvent>>, CryptoStoreError> { ) -> Result<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>), CryptoStoreError>
{
let key_info = self.get_key_info(&event.content).await?; let key_info = self.get_key_info(&event.content).await?;
if let Some(info) = key_info { if let Some(info) = key_info {
@ -630,27 +633,32 @@ impl KeyRequestMachine {
// If we have a previous session, check if we have a better version // If we have a previous session, check if we have a better version
// and store the new one if so. // and store the new one if so.
if let Some(old_session) = old_session { let session = if let Some(old_session) = old_session {
let first_old_index = old_session.first_known_index().await; let first_old_index = old_session.first_known_index().await;
let first_index = session.first_known_index().await; let first_index = session.first_known_index().await;
if first_old_index > first_index { if first_old_index > first_index {
self.save_session(info, session).await?; self.mark_as_done(info).await?;
Some(session)
} else {
None
} }
// If we didn't have a previous session, store it. // If we didn't have a previous session, store it.
} else { } else {
self.save_session(info, session).await?; self.mark_as_done(info).await?;
} Some(session)
};
Ok(Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey( Ok((
event.clone(), Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey(event.clone()))),
)))) session,
))
} else { } else {
info!( info!(
"Received a forwarded room key from {}, but no key info was found.", "Received a forwarded room key from {}, but no key info was found.",
event.sender, event.sender,
); );
Ok(None) Ok((None, None))
} }
} }
} }
@ -666,13 +674,14 @@ mod test {
AnyToDeviceEvent, ToDeviceEvent, AnyToDeviceEvent, ToDeviceEvent,
}, },
identifiers::{room_id, user_id, DeviceIdBox, RoomId, UserId}, identifiers::{room_id, user_id, DeviceIdBox, RoomId, UserId},
locks::Mutex,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use std::{convert::TryInto, sync::Arc}; use std::{convert::TryInto, sync::Arc};
use crate::{ use crate::{
identities::{LocalTrust, ReadOnlyDevice}, identities::{LocalTrust, ReadOnlyDevice},
olm::{Account, ReadOnlyAccount}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
store::{CryptoStore, MemoryStore, Store}, store::{CryptoStore, MemoryStore, Store},
verification::VerificationMachine, verification::VerificationMachine,
}; };
@ -711,7 +720,8 @@ mod test {
let user_id = Arc::new(bob_id()); let user_id = Arc::new(bob_id());
let account = ReadOnlyAccount::new(&user_id, &alice_device_id()); let account = ReadOnlyAccount::new(&user_id, &alice_device_id());
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
let verification = VerificationMachine::new(account, store.clone()); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id())));
let verification = VerificationMachine::new(account, identity, store.clone());
let store = Store::new(user_id.clone(), store, verification); let store = Store::new(user_id.clone(), store, verification);
KeyRequestMachine::new( KeyRequestMachine::new(
@ -719,6 +729,7 @@ mod test {
Arc::new(bob_device_id()), Arc::new(bob_device_id()),
store, store,
Arc::new(DashMap::new()), Arc::new(DashMap::new()),
Arc::new(DashMap::new()),
) )
} }
@ -727,7 +738,8 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &alice_device_id()); let account = ReadOnlyAccount::new(&user_id, &alice_device_id());
let device = ReadOnlyDevice::from_account(&account).await; let device = ReadOnlyDevice::from_account(&account).await;
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
let verification = VerificationMachine::new(account, store.clone()); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let verification = VerificationMachine::new(account, identity, store.clone());
let store = Store::new(user_id.clone(), store, verification); let store = Store::new(user_id.clone(), store, verification);
store.save_devices(&[device]).await.unwrap(); store.save_devices(&[device]).await.unwrap();
@ -736,6 +748,7 @@ mod test {
Arc::new(alice_device_id()), Arc::new(alice_device_id()),
store, store,
Arc::new(DashMap::new()), Arc::new(DashMap::new()),
Arc::new(DashMap::new()),
) )
} }
@ -833,20 +846,20 @@ mod test {
.is_none() .is_none()
); );
machine let (_, first_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event) .receive_forwarded_room_key(&session.sender_key, &mut event)
.await .await
.unwrap(); .unwrap();
let first_session = first_session.unwrap();
let first_session = machine
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
assert_eq!(first_session.first_known_index().await, 10); assert_eq!(first_session.first_known_index().await, 10);
machine
.store
.save_inbound_group_sessions(&[first_session.clone()])
.await
.unwrap();
// Get the cancel request. // Get the cancel request.
let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let request = machine.outgoing_to_device_requests.iter().next().unwrap();
let id = request.request_id; let id = request.request_id;
@ -877,19 +890,12 @@ mod test {
content, content,
}; };
machine let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event) .receive_forwarded_room_key(&session.sender_key, &mut event)
.await .await
.unwrap(); .unwrap();
let second_session = machine assert!(second_session.is_none());
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
assert_eq!(second_session.first_known_index().await, 10);
let export = session.export_at_index(0).await.unwrap(); let export = session.export_at_index(0).await.unwrap();
@ -900,18 +906,12 @@ mod test {
content, content,
}; };
machine let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event) .receive_forwarded_room_key(&session.sender_key, &mut event)
.await .await
.unwrap(); .unwrap();
let second_session = machine
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
assert_eq!(second_session.first_known_index().await, 0); assert_eq!(second_session.unwrap().first_known_index().await, 0);
} }
#[async_test] #[async_test]
@ -1134,14 +1134,19 @@ mod test {
.unwrap() .unwrap()
.is_none()); .is_none());
let (decrypted, sender_key, _) = let (_, decrypted, sender_key, _) =
alice_account.decrypt_to_device_event(&event).await.unwrap(); alice_account.decrypt_to_device_event(&event).await.unwrap();
if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() {
alice_machine let (_, session) = alice_machine
.receive_forwarded_room_key(&sender_key, &mut e) .receive_forwarded_room_key(&sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else { } else {
panic!("Invalid decrypted event type"); panic!("Invalid decrypted event type");
} }
@ -1317,14 +1322,19 @@ mod test {
.unwrap() .unwrap()
.is_none()); .is_none());
let (decrypted, sender_key, _) = let (_, decrypted, sender_key, _) =
alice_account.decrypt_to_device_event(&event).await.unwrap(); alice_account.decrypt_to_device_event(&event).await.unwrap();
if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() {
alice_machine let (_, session) = alice_machine
.receive_forwarded_room_key(&sender_key, &mut e) .receive_forwarded_room_key(&sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else { } else {
panic!("Invalid decrypted event type"); panic!("Invalid decrypted event type");
} }

View File

@ -29,7 +29,6 @@
mod error; mod error;
mod file_encryption; mod file_encryption;
mod group_manager;
mod identities; mod identities;
mod key_request; mod key_request;
mod machine; mod machine;

View File

@ -17,6 +17,7 @@ use std::path::Path;
use std::{collections::BTreeMap, mem, sync::Arc}; use std::{collections::BTreeMap, mem, sync::Arc};
use dashmap::DashMap; use dashmap::DashMap;
use matrix_sdk_common::locks::Mutex;
use tracing::{debug, error, info, instrument, trace, warn}; use tracing::{debug, error, info, instrument, trace, warn};
use matrix_sdk_common::{ use matrix_sdk_common::{
@ -25,6 +26,7 @@ use matrix_sdk_common::{
claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
get_keys::Response as KeysQueryResponse, get_keys::Response as KeysQueryResponse,
upload_keys, upload_keys,
upload_signatures::Request as UploadSignaturesRequest,
}, },
sync::sync_events::Response as SyncResponse, sync::sync_events::Response as SyncResponse,
}, },
@ -45,17 +47,19 @@ use matrix_sdk_common::{
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
use crate::store::sqlite::SqliteStore; use crate::store::sqlite::SqliteStore;
use crate::{ use crate::{
error::{EventError, MegolmError, MegolmResult, OlmResult}, error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult},
group_manager::GroupSessionManager, identities::{Device, IdentityManager, UserDevices},
identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities},
key_request::KeyRequestMachine, key_request::KeyRequestMachine,
olm::{ olm::{
Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys,
InboundGroupSession, ReadOnlyAccount, InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session,
},
requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest},
session_manager::{GroupSessionManager, SessionManager},
store::{
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
Store,
}, },
requests::{IncomingResponse, OutgoingRequest},
session_manager::SessionManager,
store::{CryptoStore, MemoryStore, Result as StoreResult, Store},
verification::{Sas, VerificationMachine}, verification::{Sas, VerificationMachine},
ToDeviceRequest, ToDeviceRequest,
}; };
@ -70,6 +74,11 @@ pub struct OlmMachine {
device_id: Arc<Box<DeviceId>>, device_id: Arc<Box<DeviceId>>,
/// Our underlying Olm Account holding our identity keys. /// Our underlying Olm Account holding our identity keys.
account: Account, account: Account,
/// The private part of our cross signing identity.
/// Used to sign devices and other users, might be missing if some other
/// device bootstraped cross signing or cross signing isn't bootstrapped at
/// all.
user_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
/// Store for the encryption keys. /// Store for the encryption keys.
/// Persists all the encryption keys so a client can resume the session /// Persists all the encryption keys so a client can resume the session
/// without the need to create new keys. /// without the need to create new keys.
@ -87,6 +96,7 @@ pub struct OlmMachine {
/// State machine handling public user identities and devices, keeping track /// State machine handling public user identities and devices, keeping track
/// of when a key query needs to be done and handling one. /// of when a key query needs to be done and handling one.
identity_manager: IdentityManager, identity_manager: IdentityManager,
cross_signing_request: Arc<Mutex<Option<UploadSignaturesRequest>>>,
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
@ -115,7 +125,13 @@ impl OlmMachine {
let device_id: DeviceIdBox = device_id.into(); let device_id: DeviceIdBox = device_id.into();
let account = ReadOnlyAccount::new(&user_id, &device_id); let account = ReadOnlyAccount::new(&user_id, &device_id);
OlmMachine::new_helper(user_id, device_id, store, account) OlmMachine::new_helper(
user_id,
device_id,
store,
account,
PrivateCrossSigningIdentity::empty(user_id.to_owned()),
)
} }
fn new_helper( fn new_helper(
@ -123,19 +139,25 @@ impl OlmMachine {
device_id: DeviceIdBox, device_id: DeviceIdBox,
store: Box<dyn CryptoStore>, store: Box<dyn CryptoStore>,
account: ReadOnlyAccount, account: ReadOnlyAccount,
user_identity: PrivateCrossSigningIdentity,
) -> Self { ) -> Self {
let user_id = Arc::new(user_id.clone()); let user_id = Arc::new(user_id.clone());
let user_identity = Arc::new(Mutex::new(user_identity));
let store = Arc::new(store); let store = Arc::new(store);
let verification_machine = VerificationMachine::new(account.clone(), store.clone()); let verification_machine =
VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
let store = Store::new(user_id.clone(), store, verification_machine.clone()); let store = Store::new(user_id.clone(), store, verification_machine.clone());
let device_id: Arc<DeviceIdBox> = Arc::new(device_id); let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
let outbound_group_sessions = Arc::new(DashMap::new()); let outbound_group_sessions = Arc::new(DashMap::new());
let users_for_key_claim = Arc::new(DashMap::new());
let key_request_machine = KeyRequestMachine::new( let key_request_machine = KeyRequestMachine::new(
user_id.clone(), user_id.clone(),
device_id.clone(), device_id.clone(),
store.clone(), store.clone(),
outbound_group_sessions, outbound_group_sessions,
users_for_key_claim.clone(),
); );
let account = Account { let account = Account {
@ -143,8 +165,12 @@ impl OlmMachine {
store: store.clone(), store: store.clone(),
}; };
let session_manager = let session_manager = SessionManager::new(
SessionManager::new(account.clone(), key_request_machine.clone(), store.clone()); account.clone(),
users_for_key_claim,
key_request_machine.clone(),
store.clone(),
);
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
let identity_manager = IdentityManager::new( let identity_manager = IdentityManager::new(
user_id.clone(), user_id.clone(),
@ -157,12 +183,14 @@ impl OlmMachine {
user_id, user_id,
device_id, device_id,
account, account,
user_identity,
store, store,
session_manager, session_manager,
group_session_manager, group_session_manager,
verification_machine, verification_machine,
key_request_machine, key_request_machine,
identity_manager, identity_manager,
cross_signing_request: Arc::new(Mutex::new(None)),
} }
} }
@ -197,11 +225,26 @@ impl OlmMachine {
} }
None => { None => {
debug!("Creating a new account"); debug!("Creating a new account");
ReadOnlyAccount::new(&user_id, &device_id) let account = ReadOnlyAccount::new(&user_id, &device_id);
store.save_account(account.clone()).await?;
account
} }
}; };
Ok(OlmMachine::new_helper(&user_id, device_id, store, account)) let identity = match store.load_identity().await? {
Some(i) => {
debug!("Restored the cross signing identity");
i
}
None => {
debug!("Creating an empty cross signing identity stub");
PrivateCrossSigningIdentity::empty(user_id.clone())
}
};
Ok(OlmMachine::new_helper(
&user_id, device_id, store, account, identity,
))
} }
/// Create a new machine with the default crypto store. /// Create a new machine with the default crypto store.
@ -305,11 +348,58 @@ impl OlmMachine {
IncomingResponse::ToDevice(_) => { IncomingResponse::ToDevice(_) => {
self.mark_to_device_request_as_sent(&request_id).await?; self.mark_to_device_request_as_sent(&request_id).await?;
} }
IncomingResponse::SigningKeysUpload(_) => {
self.receive_cross_signing_upload_response().await?;
}
}; };
Ok(()) Ok(())
} }
/// Mark the cross signing identity as shared.
async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> {
self.user_identity.lock().await.mark_as_shared();
self.store
.save_identity((&*self.user_identity.lock().await).clone())
.await
}
/// Create a new cross signing identity and get the upload request to push
/// the new public keys to the server.
///
/// **Warning**: This will delete any existing cross signing keys that might
/// exist on the server and thus will reset the trust between all the
/// devices.
///
/// Uploading these keys will require user interactive auth.
pub async fn bootstrap_cross_signing(
&self,
reset: bool,
) -> StoreResult<(UploadSigningKeysRequest, UploadSignaturesRequest)> {
let mut identity = self.user_identity.lock().await;
if identity.is_empty().await || reset {
info!("Creating new cross signing identity");
let (id, signature_request) = self.account.bootstrap_cross_signing().await;
let request = id.as_upload_request().await;
*identity = id;
self.store.save_identity(identity.clone()).await?;
Ok((request, signature_request))
} else {
info!("Trying to upload the existing cross signing identity");
let request = identity.as_upload_request().await;
let device_keys = self.account.unsigned_device_keys();
// TODO remove this expect.
let signature_request = identity
.sign_device(device_keys)
.await
.expect("Can't sign device keys");
Ok((request, signature_request))
}
}
/// Should device or one-time keys be uploaded to the server. /// Should device or one-time keys be uploaded to the server.
/// ///
/// This needs to be checked periodically, ideally after every sync request. /// This needs to be checked periodically, ideally after every sync request.
@ -417,7 +507,7 @@ impl OlmMachine {
async fn receive_keys_query_response( async fn receive_keys_query_response(
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
self.identity_manager self.identity_manager
.receive_keys_query_response(response) .receive_keys_query_response(response)
.await .await
@ -448,12 +538,12 @@ impl OlmMachine {
async fn decrypt_to_device_event( async fn decrypt_to_device_event(
&self, &self,
event: &ToDeviceEvent<EncryptedEventContent>, event: &ToDeviceEvent<EncryptedEventContent>,
) -> OlmResult<Raw<AnyToDeviceEvent>> { ) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, Option<InboundGroupSession>)> {
let (decrypted_event, sender_key, signing_key) = let (session, decrypted_event, sender_key, signing_key) =
self.account.decrypt_to_device_event(event).await?; self.account.decrypt_to_device_event(event).await?;
// Handle the decrypted event, e.g. fetch out Megolm sessions out of // Handle the decrypted event, e.g. fetch out Megolm sessions out of
// the event. // the event.
if let Some(event) = self if let (Some(event), group_session) = self
.handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event) .handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event)
.await? .await?
{ {
@ -462,9 +552,9 @@ impl OlmMachine {
// don't want them to be able to do silly things with it. Handling // don't want them to be able to do silly things with it. Handling
// events modifies them and returns a modified one, so replace it // events modifies them and returns a modified one, so replace it
// here if we get one. // here if we get one.
Ok(event) Ok((session, event, group_session))
} else { } else {
Ok(decrypted_event) Ok((session, decrypted_event, None))
} }
} }
@ -474,7 +564,7 @@ impl OlmMachine {
sender_key: &str, sender_key: &str,
signing_key: &str, signing_key: &str,
event: &mut ToDeviceEvent<RoomKeyEventContent>, event: &mut ToDeviceEvent<RoomKeyEventContent>,
) -> OlmResult<Option<Raw<AnyToDeviceEvent>>> { ) -> OlmResult<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>)> {
match event.content.algorithm { match event.content.algorithm {
EventEncryptionAlgorithm::MegolmV1AesSha2 => { EventEncryptionAlgorithm::MegolmV1AesSha2 => {
let session_key = GroupSessionKey(mem::take(&mut event.content.session_key)); let session_key = GroupSessionKey(mem::take(&mut event.content.session_key));
@ -485,17 +575,15 @@ impl OlmMachine {
&event.content.room_id, &event.content.room_id,
session_key, session_key,
)?; )?;
let _ = self.store.save_inbound_group_sessions(&[session]).await?;
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
Ok(Some(event)) Ok((Some(event), Some(session)))
} }
_ => { _ => {
warn!( warn!(
"Received room key with unsupported key algorithm {}", "Received room key with unsupported key algorithm {}",
event.content.algorithm event.content.algorithm
); );
Ok(None) Ok((None, None))
} }
} }
} }
@ -505,9 +593,14 @@ impl OlmMachine {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> OlmResult<()> { ) -> OlmResult<()> {
self.group_session_manager let (_, session) = self
.group_session_manager
.create_outbound_group_session(room_id, EncryptionSettings::default()) .create_outbound_group_session(room_id, EncryptionSettings::default())
.await .await?;
self.store.save_inbound_group_sessions(&[session]).await?;
Ok(())
} }
/// Encrypt a room message for the given room. /// Encrypt a room message for the given room.
@ -597,12 +690,12 @@ impl OlmMachine {
sender_key: &str, sender_key: &str,
signing_key: &str, signing_key: &str,
event: &Raw<AnyToDeviceEvent>, event: &Raw<AnyToDeviceEvent>,
) -> OlmResult<Option<Raw<AnyToDeviceEvent>>> { ) -> OlmResult<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>)> {
let event = if let Ok(e) = event.deserialize() { let event = if let Ok(e) = event.deserialize() {
e e
} else { } else {
warn!("Decrypted to-device event failed to be parsed correctly"); warn!("Decrypted to-device event failed to be parsed correctly");
return Ok(None); return Ok((None, None));
}; };
match event { match event {
@ -615,7 +708,7 @@ impl OlmMachine {
.await?), .await?),
_ => { _ => {
warn!("Received a unexpected encrypted to-device event"); warn!("Received a unexpected encrypted to-device event");
Ok(None) Ok((None, None))
} }
} }
} }
@ -638,6 +731,8 @@ impl OlmMachine {
.mark_outgoing_request_as_sent(request_id) .mark_outgoing_request_as_sent(request_id)
.await?; .await?;
self.group_session_manager.mark_request_as_sent(request_id); self.group_session_manager.mark_request_as_sent(request_id);
self.session_manager
.mark_outgoing_request_as_sent(request_id);
Ok(()) Ok(())
} }
@ -647,11 +742,8 @@ impl OlmMachine {
self.verification_machine.get_sas(flow_id) self.verification_machine.get_sas(flow_id)
} }
async fn update_one_time_key_count( async fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
&self, self.account.update_uploaded_key_count(key_count).await;
key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>,
) -> StoreResult<()> {
self.account.update_uploaded_key_count(key_count).await
} }
/// Handle a sync response and update the internal state of the Olm machine. /// Handle a sync response and update the internal state of the Olm machine.
@ -667,15 +759,19 @@ impl OlmMachine {
/// ///
/// [`decrypt_room_event`]: #method.decrypt_room_event /// [`decrypt_room_event`]: #method.decrypt_room_event
#[instrument(skip(response))] #[instrument(skip(response))]
pub async fn receive_sync_response(&self, response: &mut SyncResponse) { pub async fn receive_sync_response(&self, response: &mut SyncResponse) -> OlmResult<()> {
// Remove verification objects that have expired or are done.
self.verification_machine.garbage_collect(); self.verification_machine.garbage_collect();
if let Err(e) = self // Always save the account, a new session might get created which also
.update_one_time_key_count(&response.device_one_time_keys_count) // touches the account.
.await let mut changes = Changes {
{ account: Some(self.account.inner.clone()),
error!("Error updating the one-time key count {:?}", e); ..Default::default()
} };
self.update_one_time_key_count(&response.device_one_time_keys_count)
.await;
for user_id in &response.device_lists.changed { for user_id in &response.device_lists.changed {
if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await { if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await {
@ -696,19 +792,37 @@ impl OlmMachine {
match &mut event { match &mut event {
AnyToDeviceEvent::RoomEncrypted(e) => { AnyToDeviceEvent::RoomEncrypted(e) => {
let decrypted_event = match self.decrypt_to_device_event(e).await { let (session, decrypted_event, group_session) =
match self.decrypt_to_device_event(e).await {
Ok(e) => e, Ok(e) => e,
Err(err) => { Err(err) => {
warn!( warn!(
"Failed to decrypt to-device event from {} {}", "Failed to decrypt to-device event from {} {}",
e.sender, err e.sender, err
); );
// TODO if the session is wedged mark it for
// unwedging. if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self
.session_manager
.mark_device_as_wedged(&sender, &curve_key)
.await
{
error!(
"Couldn't mark device from {} to be unwedged {:?}",
sender, e
);
}
}
continue; continue;
} }
}; };
changes.sessions.push(session);
if let Some(group_session) = group_session {
changes.inbound_group_sessions.push(group_session);
}
*event_result = decrypted_event; *event_result = decrypted_event;
} }
AnyToDeviceEvent::RoomKeyRequest(e) => { AnyToDeviceEvent::RoomKeyRequest(e) => {
@ -726,13 +840,14 @@ impl OlmMachine {
} }
} }
if let Err(e) = self let changed_sessions = self
.key_request_machine .key_request_machine
.collect_incoming_key_requests() .collect_incoming_key_requests()
.await .await?;
{
error!("Error collecting our key share requests {:?}", e); changes.sessions.extend(changed_sessions);
}
Ok(self.store.save_changes(changes).await?)
} }
/// Decrypt an event from a room timeline. /// Decrypt an event from a room timeline.
@ -887,6 +1002,7 @@ impl OlmMachine {
// Only import the session if we didn't have this session or if it's // Only import the session if we didn't have this session or if it's
// a better version of the same session, that is the first known // a better version of the same session, that is the first known
// index is lower. // index is lower.
// TODO load all sessions so we don't do a thousand small loads.
if let Some(existing_session) = self if let Some(existing_session) = self
.store .store
.get_inbound_group_session( .get_inbound_group_session(
@ -909,7 +1025,17 @@ impl OlmMachine {
let num_sessions = sessions.len(); let num_sessions = sessions.len();
self.store.save_inbound_group_sessions(&sessions).await?; let changes = Changes {
inbound_group_sessions: sessions,
..Default::default()
};
self.store.save_changes(changes).await?;
info!(
"Successfully imported {} inbound group sessions",
num_sessions
);
Ok(num_sessions) Ok(num_sessions)
} }
@ -1130,15 +1256,19 @@ pub(crate) mod test {
.unwrap() .unwrap()
.unwrap(); .unwrap();
let event = ToDeviceEvent { let (session, content) = bob_device
sender: alice.user_id().clone(),
content: bob_device
.encrypt(EventType::Dummy, json!({})) .encrypt(EventType::Dummy, json!({}))
.await .await
.unwrap(), .unwrap();
alice.store.save_sessions(&[session]).await.unwrap();
let event = ToDeviceEvent {
sender: alice.user_id().clone(),
content,
}; };
bob.decrypt_to_device_event(&event).await.unwrap(); let (session, _, _) = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store.save_sessions(&[session]).await.unwrap();
(alice, bob) (alice, bob)
} }
@ -1424,13 +1554,15 @@ pub(crate) mod test {
content: bob_device content: bob_device
.encrypt(EventType::Dummy, json!({})) .encrypt(EventType::Dummy, json!({}))
.await .await
.unwrap(), .unwrap()
.1,
}; };
let event = bob let event = bob
.decrypt_to_device_event(&event) .decrypt_to_device_event(&event)
.await .await
.unwrap() .unwrap()
.1
.deserialize() .deserialize()
.unwrap(); .unwrap();
@ -1466,12 +1598,14 @@ pub(crate) mod test {
.get_outbound_group_session(&room_id) .get_outbound_group_session(&room_id)
.unwrap(); .unwrap();
let event = bob let (session, event, group_session) = bob.decrypt_to_device_event(&event).await.unwrap();
.decrypt_to_device_event(&event)
bob.store.save_sessions(&[session]).await.unwrap();
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await .await
.unwrap()
.deserialize()
.unwrap(); .unwrap();
let event = event.deserialize().unwrap();
if let AnyToDeviceEvent::RoomKey(event) = event { if let AnyToDeviceEvent::RoomKey(event) = event {
assert_eq!(&event.sender, alice.user_id()); assert_eq!(&event.sender, alice.user_id());
@ -1511,7 +1645,11 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests), content: to_device_requests_to_content(to_device_requests),
}; };
bob.decrypt_to_device_event(&event).await.unwrap(); let (_, _, group_session) = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await
.unwrap();
let plaintext = "It is a secret to everybody"; let plaintext = "It is a secret to everybody";
@ -1557,7 +1695,7 @@ pub(crate) mod test {
} }
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
async fn test_machine_with_default_store() { async fn test_machine_with_default_store() {
let tmpdir = tempdir().unwrap(); let tmpdir = tempdir().unwrap();

View File

@ -30,7 +30,9 @@ use tracing::{debug, trace, warn};
#[cfg(test)] #[cfg(test)]
use matrix_sdk_common::events::EventType; use matrix_sdk_common::events::EventType;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::{upload_keys, OneTimeKey, SignedKey}, api::r0::keys::{
upload_keys, upload_signatures::Request as SignatureUploadRequest, OneTimeKey, SignedKey,
},
encryption::DeviceKeys, encryption::DeviceKeys,
events::{room::encrypted::EncryptedEventContent, AnyToDeviceEvent}, events::{room::encrypted::EncryptedEventContent, AnyToDeviceEvent},
identifiers::{ identifiers::{
@ -50,13 +52,16 @@ use olm_rs::{
}; };
use crate::{ use crate::{
error::{EventError, OlmResult, SessionCreationError}, error::{EventError, OlmResult, SessionCreationError, SignatureError},
identities::ReadOnlyDevice, identities::ReadOnlyDevice,
store::{Result as StoreResult, Store}, store::Store,
OlmError, OlmError,
}; };
use super::{EncryptionSettings, InboundGroupSession, OutboundGroupSession, Session}; use super::{
EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity,
Session,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Account { pub struct Account {
@ -76,7 +81,7 @@ impl Account {
pub async fn decrypt_to_device_event( pub async fn decrypt_to_device_event(
&self, &self,
event: &ToDeviceEvent<EncryptedEventContent>, event: &ToDeviceEvent<EncryptedEventContent>,
) -> OlmResult<(Raw<AnyToDeviceEvent>, String, String)> { ) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, String, String)> {
debug!("Decrypting to-device event"); debug!("Decrypting to-device event");
let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content { let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content {
@ -103,27 +108,28 @@ impl Account {
.map_err(|_| EventError::UnsupportedOlmType)?; .map_err(|_| EventError::UnsupportedOlmType)?;
// Decrypt the OlmMessage and get a Ruma event out of it. // Decrypt the OlmMessage and get a Ruma event out of it.
let (decrypted_event, signing_key) = self let (session, decrypted_event, signing_key) = self
.decrypt_olm_message(&event.sender, &content.sender_key, message) .decrypt_olm_message(&event.sender, &content.sender_key, message)
.await?; .await?;
debug!("Decrypted a to-device event {:?}", decrypted_event); debug!("Decrypted a to-device event {:?}", decrypted_event);
Ok((decrypted_event, content.sender_key.clone(), signing_key)) Ok((
session,
decrypted_event,
content.sender_key.clone(),
signing_key,
))
} else { } else {
warn!("Olm event doesn't contain a ciphertext for our key"); warn!("Olm event doesn't contain a ciphertext for our key");
Err(EventError::MissingCiphertext.into()) Err(EventError::MissingCiphertext.into())
} }
} }
pub async fn update_uploaded_key_count( pub async fn update_uploaded_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
&self,
key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>,
) -> StoreResult<()> {
let one_time_key_count = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519); let one_time_key_count = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
self.inner.update_uploaded_key_count(count); self.inner.update_uploaded_key_count(count);
self.store.save_account(self.inner.clone()).await
} }
pub async fn receive_keys_upload_response( pub async fn receive_keys_upload_response(
@ -161,7 +167,7 @@ impl Account {
sender: &UserId, sender: &UserId,
sender_key: &str, sender_key: &str,
message: &OlmMessage, message: &OlmMessage,
) -> OlmResult<Option<String>> { ) -> OlmResult<Option<(Session, String)>> {
let s = self.store.get_sessions(sender_key).await?; let s = self.store.get_sessions(sender_key).await?;
// We don't have any existing sessions, return early. // We don't have any existing sessions, return early.
@ -171,8 +177,7 @@ impl Account {
return Ok(None); return Ok(None);
}; };
let mut session_to_save = None; let mut decrypted: Option<(Session, String)> = None;
let mut plaintext = None;
for session in &mut *sessions.lock().await { for session in &mut *sessions.lock().await {
let mut matches = false; let mut matches = false;
@ -191,9 +196,7 @@ impl Account {
match ret { match ret {
Ok(p) => { Ok(p) => {
plaintext = Some(p); decrypted = Some((session.clone(), p));
session_to_save = Some(session.clone());
break; break;
} }
Err(e) => { Err(e) => {
@ -205,20 +208,16 @@ impl Account {
for sender {} and sender_key {} {:?}", for sender {} and sender_key {} {:?}",
sender, sender_key, e sender, sender_key, e
); );
return Err(OlmError::SessionWedged); return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
} }
} }
} }
} }
if let Some(session) = session_to_save { Ok(decrypted)
// Decryption was successful, save the new ratchet state of the
// session that was used to decrypt the message.
trace!("Saved the new session state for {}", sender);
self.store.save_sessions(&[session]).await?;
}
Ok(plaintext)
} }
/// Decrypt an Olm message, creating a new Olm session if possible. /// Decrypt an Olm message, creating a new Olm session if possible.
@ -227,15 +226,15 @@ impl Account {
sender: &UserId, sender: &UserId,
sender_key: &str, sender_key: &str,
message: OlmMessage, message: OlmMessage,
) -> OlmResult<(Raw<AnyToDeviceEvent>, String)> { ) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, String)> {
// First try to decrypt using an existing session. // First try to decrypt using an existing session.
let plaintext = if let Some(p) = self let (session, plaintext) = if let Some(d) = self
.try_decrypt_olm_message(sender, sender_key, &message) .try_decrypt_olm_message(sender, sender_key, &message)
.await? .await?
{ {
// Decryption succeeded, de-structure the plaintext out of the // Decryption succeeded, de-structure the session/plaintext out of
// Option. // the Option.
p d
} else { } else {
// Decryption failed with every known session, let's try to create a // Decryption failed with every known session, let's try to create a
// new session. // new session.
@ -248,7 +247,10 @@ impl Account {
available sessions {} {}", available sessions {} {}",
sender, sender_key sender, sender_key
); );
return Err(OlmError::SessionWedged); return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
} }
OlmMessage::PreKey(m) => { OlmMessage::PreKey(m) => {
@ -265,13 +267,13 @@ impl Account {
from a prekey message: {}", from a prekey message: {}",
sender, sender_key, e sender, sender_key, e
); );
return Err(OlmError::SessionWedged); return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
} }
}; };
// Save the account since we remove the one-time key that
// was used to create this session.
self.store.save_account(self.inner.clone()).await?;
session session
} }
}; };
@ -279,15 +281,23 @@ impl Account {
// Decrypt our message, this shouldn't fail since we're using a // Decrypt our message, this shouldn't fail since we're using a
// newly created Session. // newly created Session.
let plaintext = session.decrypt(message).await?; let plaintext = session.decrypt(message).await?;
(session, plaintext)
// Save the new ratcheted state of the session.
self.store.save_sessions(&[session]).await?;
plaintext
}; };
trace!("Successfully decrypted a Olm message: {}", plaintext); trace!("Successfully decrypted a Olm message: {}", plaintext);
self.parse_decrypted_to_device_event(sender, &plaintext) let (event, signing_key) = match self.parse_decrypted_to_device_event(sender, &plaintext) {
Ok(r) => r,
Err(e) => {
// We might created a new session but decryption might still
// have failed, store it for the error case here, this is fine
// since we don't expect this to happen often or at all.
self.store.save_sessions(&[session]).await?;
return Err(e);
}
};
Ok((session, event, signing_key))
} }
/// Parse a decrypted Olm message, check that the plaintext and encrypted /// Parse a decrypted Olm message, check that the plaintext and encrypted
@ -613,9 +623,7 @@ impl ReadOnlyAccount {
}) })
} }
/// Sign the device keys of the account and return them so they can be pub(crate) fn unsigned_device_keys(&self) -> DeviceKeys {
/// uploaded.
pub(crate) async fn device_keys(&self) -> DeviceKeys {
let identity_keys = self.identity_keys(); let identity_keys = self.identity_keys();
let mut keys = BTreeMap::new(); let mut keys = BTreeMap::new();
@ -629,34 +637,41 @@ impl ReadOnlyAccount {
identity_keys.ed25519().to_owned(), identity_keys.ed25519().to_owned(),
); );
let device_keys = json!({
"user_id": (*self.user_id).clone(),
"device_id": (*self.device_id).clone(),
"algorithms": Self::ALGORITHMS,
"keys": keys,
});
let mut signatures = BTreeMap::new();
let mut signature = BTreeMap::new();
signature.insert(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
self.sign_json(&device_keys).await,
);
signatures.insert((*self.user_id).clone(), signature);
DeviceKeys::new( DeviceKeys::new(
(*self.user_id).clone(), (*self.user_id).clone(),
(*self.device_id).clone(), (*self.device_id).clone(),
vec![ Self::ALGORITHMS.iter().map(|a| (&**a).clone()).collect(),
EventEncryptionAlgorithm::OlmV1Curve25519AesSha2,
EventEncryptionAlgorithm::MegolmV1AesSha2,
],
keys, keys,
signatures, BTreeMap::new(),
) )
} }
/// Sign the device keys of the account and return them so they can be
/// uploaded.
pub(crate) async fn device_keys(&self) -> DeviceKeys {
let mut device_keys = self.unsigned_device_keys();
let jsond_device_keys = serde_json::to_value(&device_keys).unwrap();
device_keys
.signatures
.entry(self.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
self.sign_json(jsond_device_keys)
.await
.expect("Can't sign own device keys"),
);
device_keys
}
pub(crate) async fn bootstrap_cross_signing(
&self,
) -> (PrivateCrossSigningIdentity, SignatureUploadRequest) {
PrivateCrossSigningIdentity::new_with_account(self).await
}
/// Convert a JSON value to the canonical representation and sign the JSON /// Convert a JSON value to the canonical representation and sign the JSON
/// string. /// string.
/// ///
@ -668,20 +683,18 @@ impl ReadOnlyAccount {
/// # Panic /// # Panic
/// ///
/// Panics if the json value can't be serialized. /// Panics if the json value can't be serialized.
pub async fn sign_json(&self, json: &Value) -> String { pub async fn sign_json(&self, mut json: Value) -> Result<String, SignatureError> {
let canonical_json = cjson::to_string(json) let json_object = json.as_object_mut().ok_or(SignatureError::NotAnObject)?;
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json))); let _ = json_object.remove("unsigned");
self.sign(&canonical_json).await let _ = json_object.remove("signatures");
let canonical_json = cjson::to_string(&json)?;
Ok(self.sign(&canonical_json).await)
} }
/// Generate, sign and prepare one-time keys to be uploaded. pub(crate) async fn signed_one_time_keys_helper(
///
/// If no one-time keys need to be uploaded returns an empty error.
pub(crate) async fn signed_one_time_keys(
&self, &self,
) -> Result<BTreeMap<DeviceKeyId, OneTimeKey>, ()> { ) -> Result<BTreeMap<DeviceKeyId, OneTimeKey>, ()> {
let _ = self.generate_one_time_keys().await?;
let one_time_keys = self.one_time_keys().await; let one_time_keys = self.one_time_keys().await;
let mut one_time_key_map = BTreeMap::new(); let mut one_time_key_map = BTreeMap::new();
@ -690,7 +703,10 @@ impl ReadOnlyAccount {
"key": key, "key": key,
}); });
let signature = self.sign_json(&key_json).await; let signature = self
.sign_json(key_json)
.await
.expect("Can't sign own one-time keys");
let mut signature_map = BTreeMap::new(); let mut signature_map = BTreeMap::new();
@ -719,6 +735,16 @@ impl ReadOnlyAccount {
Ok(one_time_key_map) Ok(one_time_key_map)
} }
/// Generate, sign and prepare one-time keys to be uploaded.
///
/// If no one-time keys need to be uploaded returns an empty error.
pub(crate) async fn signed_one_time_keys(
&self,
) -> Result<BTreeMap<DeviceKeyId, OneTimeKey>, ()> {
let _ = self.generate_one_time_keys().await?;
self.signed_one_time_keys_helper().await
}
/// Create a new session with another account given a one-time key. /// Create a new session with another account given a one-time key.
/// ///
/// Returns the newly created session or a `OlmSessionError` if creating a /// Returns the newly created session or a `OlmSessionError` if creating a

View File

@ -20,6 +20,7 @@
mod account; mod account;
mod group_sessions; mod group_sessions;
mod session; mod session;
mod signing;
mod utility; mod utility;
pub(crate) use account::Account; pub(crate) use account::Account;
@ -31,6 +32,7 @@ pub use group_sessions::{
pub(crate) use group_sessions::{GroupSessionKey, OutboundGroupSession}; pub(crate) use group_sessions::{GroupSessionKey, OutboundGroupSession};
pub use olm_rs::{account::IdentityKeys, PicklingMode}; pub use olm_rs::{account::IdentityKeys, PicklingMode};
pub use session::{PickledSession, Session, SessionPickle}; pub use session::{PickledSession, Session, SessionPickle};
pub use signing::{PickledCrossSigningIdentity, PrivateCrossSigningIdentity};
pub(crate) use utility::Utility; pub(crate) use utility::Utility;
#[cfg(test)] #[cfg(test)]

View File

@ -0,0 +1,772 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(dead_code, missing_docs)]
use aes_gcm::{
aead::{generic_array::GenericArray, Aead, NewAead},
Aes256Gcm,
};
use base64::{decode_config, encode_config, DecodeError, URL_SAFE_NO_PAD};
use getrandom::getrandom;
use matrix_sdk_common::{
encryption::DeviceKeys,
identifiers::{DeviceKeyAlgorithm, DeviceKeyId},
};
use serde::{Deserialize, Serialize};
use serde_json::{Error as JsonError, Value};
use std::{
collections::BTreeMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use thiserror::Error;
use zeroize::Zeroizing;
use olm_rs::{errors::OlmUtilityError, pk::OlmPkSigning, utility::OlmUtility};
use matrix_sdk_common::{
api::r0::keys::{
upload_signatures::Request as SignatureUploadRequest, CrossSigningKey, KeyUsage,
},
identifiers::UserId,
locks::Mutex,
};
use crate::{
error::SignatureError,
identities::{MasterPubkey, SelfSigningPubkey, UserSigningPubkey},
requests::UploadSigningKeysRequest,
UserIdentity,
};
use crate::ReadOnlyAccount;
const NONCE_SIZE: usize = 12;
fn encode<T: AsRef<[u8]>>(input: T) -> String {
encode_config(input, URL_SAFE_NO_PAD)
}
fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> {
decode_config(input, URL_SAFE_NO_PAD)
}
/// Error type reporting failures in the Signign operations.
#[derive(Debug, Error)]
pub enum SigningError {
/// Error decoding the base64 encoded pickle data.
#[error(transparent)]
Decode(#[from] DecodeError),
/// Error decrypting the pickled signing seed
#[error("Error decrypting the pickled signign seed")]
Decryption(String),
/// Error deserializing the pickle data.
#[error(transparent)]
Json(#[from] JsonError),
}
#[derive(Clone)]
pub struct Signing {
inner: Arc<Mutex<OlmPkSigning>>,
seed: Arc<Zeroizing<Vec<u8>>>,
public_key: PublicSigningKey,
}
impl std::fmt::Debug for Signing {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Signing")
.field("public_key", &self.public_key.as_str())
.finish()
}
}
impl PartialEq for Signing {
fn eq(&self, other: &Signing) -> bool {
self.seed == other.seed
}
}
#[derive(Clone, PartialEq, Debug)]
struct MasterSigning {
inner: Signing,
public_key: MasterPubkey,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
struct PickledMasterSigning {
pickle: PickledSigning,
public_key: CrossSigningKey,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
struct PickledUserSigning {
pickle: PickledSigning,
public_key: CrossSigningKey,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
struct PickledSelfSigning {
pickle: PickledSigning,
public_key: CrossSigningKey,
}
impl MasterSigning {
async fn pickle(&self, pickle_key: &[u8]) -> PickledMasterSigning {
let pickle = self.inner.pickle(pickle_key).await;
let public_key = self.public_key.clone().into();
PickledMasterSigning { pickle, public_key }
}
fn from_pickle(pickle: PickledMasterSigning, pickle_key: &[u8]) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
}
async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
// TODO create a borrowed version of a cross singing key.
let subkey_wihtout_signatures = CrossSigningKey {
user_id: subkey.user_id.clone(),
keys: subkey.keys.clone(),
usage: subkey.usage.clone(),
signatures: BTreeMap::new(),
};
let message = cjson::to_string(&subkey_wihtout_signatures)
.expect("Can't serialize cross signing subkey");
let signature = self.inner.sign(&message).await;
subkey
.signatures
.entry(self.public_key.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
format!("ed25519:{}", self.inner.public_key().as_str()),
signature.0,
);
}
}
impl UserSigning {
async fn pickle(&self, pickle_key: &[u8]) -> PickledUserSigning {
let pickle = self.inner.pickle(pickle_key).await;
let public_key = self.public_key.clone().into();
PickledUserSigning { pickle, public_key }
}
async fn sign_user(&self, _: &UserIdentity) -> BTreeMap<UserId, BTreeMap<String, Value>> {
todo!();
}
fn from_pickle(pickle: PickledUserSigning, pickle_key: &[u8]) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
}
}
impl SelfSigning {
async fn pickle(&self, pickle_key: &[u8]) -> PickledSelfSigning {
let pickle = self.inner.pickle(pickle_key).await;
let public_key = self.public_key.clone().into();
PickledSelfSigning { pickle, public_key }
}
async fn sign_device_raw(&self, value: Value) -> Result<Signature, SignatureError> {
self.inner.sign_json(value).await
}
async fn sign_device(&self, device_keys: &mut DeviceKeys) -> Result<(), SignatureError> {
let json_device = serde_json::to_value(&device_keys)?;
let signature = self.sign_device_raw(json_device).await?;
device_keys
.signatures
.entry(self.public_key.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceKeyId::from_parts(
DeviceKeyAlgorithm::Ed25519,
self.inner.public_key.as_str().into(),
),
signature.0,
);
Ok(())
}
fn from_pickle(pickle: PickledSelfSigning, pickle_key: &[u8]) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
}
}
#[derive(Clone, PartialEq, Debug)]
struct SelfSigning {
inner: Signing,
public_key: SelfSigningPubkey,
}
#[derive(Clone, PartialEq, Debug)]
struct UserSigning {
inner: Signing,
public_key: UserSigningPubkey,
}
#[derive(Clone, Debug)]
pub struct PrivateCrossSigningIdentity {
user_id: Arc<UserId>,
shared: Arc<AtomicBool>,
master_key: Arc<Mutex<Option<MasterSigning>>>,
user_signing_key: Arc<Mutex<Option<UserSigning>>>,
self_signing_key: Arc<Mutex<Option<SelfSigning>>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PickledCrossSigningIdentity {
pub user_id: UserId,
pub shared: bool,
pub pickle: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PickledSignings {
master_key: Option<PickledMasterSigning>,
user_signing_key: Option<PickledUserSigning>,
self_signing_key: Option<PickledSelfSigning>,
}
#[derive(Debug, Clone)]
pub struct Signature(String);
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PickledSigning(String);
#[derive(Debug, Clone, Serialize, Deserialize)]
struct InnerPickle {
version: u8,
nonce: String,
ciphertext: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct PublicSigningKey(Arc<String>);
impl Signature {
fn as_str(&self) -> &str {
&self.0
}
}
impl PickledSigning {
fn as_str(&self) -> &str {
&self.0
}
}
impl PublicSigningKey {
fn as_str(&self) -> &str {
&self.0
}
#[allow(clippy::inherent_to_string)]
fn to_string(&self) -> String {
self.0.to_string()
}
}
impl Signing {
fn new() -> Self {
let seed = OlmPkSigning::generate_seed();
Self::from_seed(seed)
}
fn from_seed(seed: Vec<u8>) -> Self {
let inner = OlmPkSigning::new(seed.clone()).expect("Unable to create pk signing object");
let public_key = PublicSigningKey(Arc::new(inner.public_key().to_owned()));
Signing {
inner: Arc::new(Mutex::new(inner)),
seed: Arc::new(Zeroizing::from(seed)),
public_key,
}
}
fn from_pickle(pickle: PickledSigning, pickle_key: &[u8]) -> Result<Self, SigningError> {
let pickled: InnerPickle = serde_json::from_str(pickle.as_str())?;
let key = GenericArray::from_slice(pickle_key);
let cipher = Aes256Gcm::new(key);
let nonce = decode(pickled.nonce)?;
let nonce = GenericArray::from_slice(&nonce);
let ciphertext = &decode(pickled.ciphertext)?;
let seed = cipher
.decrypt(&nonce, ciphertext.as_slice())
.map_err(|e| SigningError::Decryption(e.to_string()))?;
Ok(Self::from_seed(seed))
}
async fn pickle(&self, pickle_key: &[u8]) -> PickledSigning {
let key = GenericArray::from_slice(pickle_key);
let cipher = Aes256Gcm::new(key);
let mut nonce = vec![0u8; NONCE_SIZE];
getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object");
let nonce = GenericArray::from_slice(nonce.as_slice());
let ciphertext = cipher
.encrypt(nonce, self.seed.as_slice())
.expect("Can't encrypt signing pickle");
let ciphertext = encode(ciphertext);
let pickle = InnerPickle {
version: 1,
nonce: encode(nonce.as_slice()),
ciphertext,
};
PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing"))
}
fn public_key(&self) -> &PublicSigningKey {
&self.public_key
}
fn cross_signing_key(&self, user_id: UserId, usage: KeyUsage) -> CrossSigningKey {
let mut keys = BTreeMap::new();
keys.insert(
format!("ed25519:{}", self.public_key().as_str()),
self.public_key().to_string(),
);
CrossSigningKey {
user_id,
usage: vec![usage],
keys,
signatures: BTreeMap::new(),
}
}
async fn verify(&self, message: &str, signature: &Signature) -> Result<bool, OlmUtilityError> {
let utility = OlmUtility::new();
utility.ed25519_verify(self.public_key.as_str(), message, signature.as_str())
}
async fn sign_json(&self, mut json: Value) -> Result<Signature, SignatureError> {
let json_object = json.as_object_mut().ok_or(SignatureError::NotAnObject)?;
let _ = json_object.remove("signatures");
let canonical_json = cjson::to_string(json_object)?;
Ok(self.sign(&canonical_json).await)
}
async fn sign(&self, message: &str) -> Signature {
Signature(self.inner.lock().await.sign(message))
}
}
impl PrivateCrossSigningIdentity {
pub fn user_id(&self) -> &UserId {
&self.user_id
}
pub async fn is_empty(&self) -> bool {
let has_master = self.master_key.lock().await.is_some();
let has_user = self.user_signing_key.lock().await.is_some();
let has_self = self.self_signing_key.lock().await.is_some();
!(has_master && has_user && has_self)
}
pub(crate) fn empty(user_id: UserId) -> Self {
Self {
user_id: Arc::new(user_id),
shared: Arc::new(AtomicBool::new(false)),
master_key: Arc::new(Mutex::new(None)),
self_signing_key: Arc::new(Mutex::new(None)),
user_signing_key: Arc::new(Mutex::new(None)),
}
}
pub(crate) async fn sign_device(
&self,
mut device_keys: DeviceKeys,
) -> Result<SignatureUploadRequest, SignatureError> {
self.self_signing_key
.lock()
.await
.as_ref()
.ok_or(SignatureError::MissingSigningKey)?
.sign_device(&mut device_keys)
.await?;
let mut signed_keys = BTreeMap::new();
signed_keys
.entry((&*self.user_id).to_owned())
.or_insert_with(BTreeMap::new)
.insert(
device_keys.device_id.to_string(),
serde_json::to_value(device_keys)?,
);
Ok(SignatureUploadRequest { signed_keys })
}
pub(crate) async fn new_with_account(
account: &ReadOnlyAccount,
) -> (Self, SignatureUploadRequest) {
let master = Signing::new();
let mut public_key =
master.cross_signing_key(account.user_id().to_owned(), KeyUsage::Master);
let signature = account
.sign_json(
serde_json::to_value(&public_key)
.expect("Can't convert own public master key to json"),
)
.await
.expect("Can't sign own public master key");
public_key
.signatures
.entry(account.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(format!("ed25519:{}", account.device_id()), signature);
let master = MasterSigning {
inner: master,
public_key: public_key.into(),
};
let identity = Self::new_helper(account.user_id(), master).await;
let device_keys = account.unsigned_device_keys();
let request = identity
.sign_device(device_keys)
.await
.expect("Can't sign own device with new cross signign keys");
(identity, request)
}
async fn new_helper(user_id: &UserId, master: MasterSigning) -> Self {
let user = Signing::new();
let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning);
master.sign_subkey(&mut public_key).await;
let user = UserSigning {
inner: user,
public_key: public_key.into(),
};
let self_signing = Signing::new();
let mut public_key =
self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning);
master.sign_subkey(&mut public_key).await;
let self_signing = SelfSigning {
inner: self_signing,
public_key: public_key.into(),
};
Self {
user_id: Arc::new(user_id.to_owned()),
shared: Arc::new(AtomicBool::new(false)),
master_key: Arc::new(Mutex::new(Some(master))),
self_signing_key: Arc::new(Mutex::new(Some(self_signing))),
user_signing_key: Arc::new(Mutex::new(Some(user))),
}
}
pub(crate) async fn new(user_id: UserId) -> Self {
let master = Signing::new();
let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master);
let master = MasterSigning {
inner: master,
public_key: public_key.into(),
};
let user = Signing::new();
let mut public_key = user.cross_signing_key(user_id.clone(), KeyUsage::UserSigning);
master.sign_subkey(&mut public_key).await;
let user = UserSigning {
inner: user,
public_key: public_key.into(),
};
let self_signing = Signing::new();
let mut public_key = self_signing.cross_signing_key(user_id.clone(), KeyUsage::SelfSigning);
master.sign_subkey(&mut public_key).await;
let self_signing = SelfSigning {
inner: self_signing,
public_key: public_key.into(),
};
Self {
user_id: Arc::new(user_id),
shared: Arc::new(AtomicBool::new(false)),
master_key: Arc::new(Mutex::new(Some(master))),
self_signing_key: Arc::new(Mutex::new(Some(self_signing))),
user_signing_key: Arc::new(Mutex::new(Some(user))),
}
}
pub fn mark_as_shared(&self) {
self.shared.store(true, Ordering::SeqCst)
}
pub fn shared(&self) -> bool {
self.shared.load(Ordering::SeqCst)
}
pub async fn pickle(
&self,
pickle_key: &[u8],
) -> Result<PickledCrossSigningIdentity, JsonError> {
let master_key = if let Some(m) = self.master_key.lock().await.as_ref() {
Some(m.pickle(pickle_key).await)
} else {
None
};
let self_signing_key = if let Some(m) = self.self_signing_key.lock().await.as_ref() {
Some(m.pickle(pickle_key).await)
} else {
None
};
let user_signing_key = if let Some(m) = self.user_signing_key.lock().await.as_ref() {
Some(m.pickle(pickle_key).await)
} else {
None
};
let pickle = PickledSignings {
master_key,
user_signing_key,
self_signing_key,
};
let pickle = serde_json::to_string(&pickle)?;
Ok(PickledCrossSigningIdentity {
user_id: self.user_id.as_ref().to_owned(),
shared: self.shared(),
pickle,
})
}
/// Restore the private cross signing identity from a pickle.
///
/// # Panic
///
/// Panics if the pickle_key isn't 32 bytes long.
pub async fn from_pickle(
pickle: PickledCrossSigningIdentity,
pickle_key: &[u8],
) -> Result<Self, SigningError> {
let signings: PickledSignings = serde_json::from_str(&pickle.pickle)?;
let master = if let Some(m) = signings.master_key {
Some(MasterSigning::from_pickle(m, pickle_key)?)
} else {
None
};
let self_signing = if let Some(s) = signings.self_signing_key {
Some(SelfSigning::from_pickle(s, pickle_key)?)
} else {
None
};
let user_signing = if let Some(u) = signings.user_signing_key {
Some(UserSigning::from_pickle(u, pickle_key)?)
} else {
None
};
Ok(Self {
user_id: Arc::new(pickle.user_id),
shared: Arc::new(AtomicBool::from(pickle.shared)),
master_key: Arc::new(Mutex::new(master)),
self_signing_key: Arc::new(Mutex::new(self_signing)),
user_signing_key: Arc::new(Mutex::new(user_signing)),
})
}
pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest {
let master_key = self
.master_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let user_signing_key = self
.user_signing_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let self_signing_key = self
.self_signing_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
UploadSigningKeysRequest {
master_key,
user_signing_key,
self_signing_key,
}
}
}
#[cfg(test)]
mod test {
use crate::olm::ReadOnlyAccount;
use super::{PrivateCrossSigningIdentity, Signing};
use matrix_sdk_common::identifiers::{user_id, UserId};
use matrix_sdk_test::async_test;
fn user_id() -> UserId {
user_id!("@example:localhost")
}
fn pickle_key() -> &'static [u8] {
&[0u8; 32]
}
#[test]
fn signing_creation() {
let signing = Signing::new();
assert!(!signing.public_key().as_str().is_empty());
}
#[async_test]
async fn signature_verification() {
let signing = Signing::new();
let message = "Hello world";
let signature = signing.sign(message).await;
assert!(signing.verify(message, &signature).await.is_ok());
}
#[async_test]
async fn pickling_signing() {
let signing = Signing::new();
let pickled = signing.pickle(pickle_key()).await;
let unpickled = Signing::from_pickle(pickled, pickle_key()).unwrap();
assert_eq!(signing.public_key(), unpickled.public_key());
}
#[async_test]
async fn private_identity_creation() {
let identity = PrivateCrossSigningIdentity::new(user_id()).await;
let master_key = identity.master_key.lock().await;
let master_key = master_key.as_ref().unwrap();
assert!(master_key
.public_key
.verify_subkey(
&identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.is_ok());
assert!(master_key
.public_key
.verify_subkey(
&identity
.user_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.is_ok());
}
#[async_test]
async fn identity_pickling() {
let identity = PrivateCrossSigningIdentity::new(user_id()).await;
let pickled = identity.pickle(pickle_key()).await.unwrap();
let unpickled = PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key())
.await
.unwrap();
assert_eq!(identity.user_id, unpickled.user_id);
assert_eq!(
&*identity.master_key.lock().await,
&*unpickled.master_key.lock().await
);
assert_eq!(
&*identity.user_signing_key.lock().await,
&*unpickled.user_signing_key.lock().await
);
assert_eq!(
&*identity.self_signing_key.lock().await,
&*unpickled.self_signing_key.lock().await
);
}
#[async_test]
async fn private_identity_signed_by_accound() {
let account = ReadOnlyAccount::new(&user_id(), "DEVICEID".into());
let (identity, _) = PrivateCrossSigningIdentity::new_with_account(&account).await;
let master = identity.master_key.lock().await;
let master = master.as_ref().unwrap();
assert!(!master.public_key.signatures().is_empty());
}
}

View File

@ -20,6 +20,8 @@ use matrix_sdk_common::{
claim_keys::Response as KeysClaimResponse, claim_keys::Response as KeysClaimResponse,
get_keys::Response as KeysQueryResponse, get_keys::Response as KeysQueryResponse,
upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse}, upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse},
upload_signing_keys::Response as SigningKeysUploadResponse,
CrossSigningKey,
}, },
to_device::{send_event_to_device::Response as ToDeviceResponse, DeviceIdOrAllDevices}, to_device::{send_event_to_device::Response as ToDeviceResponse, DeviceIdOrAllDevices},
}, },
@ -56,6 +58,21 @@ impl ToDeviceRequest {
} }
} }
/// Request that will publish a cross signing identity.
///
/// This uploads the public cross signing key triplet.
#[derive(Debug, Clone)]
pub struct UploadSigningKeysRequest {
/// The user's master key.
pub master_key: Option<CrossSigningKey>,
/// The user's self-signing key. Must be signed with the accompanied master, or by the
/// user's most recently uploaded master key if no master key is included in the request.
pub self_signing_key: Option<CrossSigningKey>,
/// The user's user-signing key. Must be signed with the accompanied master, or by the
/// user's most recently uploaded master key if no master key is included in the request.
pub user_signing_key: Option<CrossSigningKey>,
}
/// Customized version of `ruma_client_api::r0::keys::get_keys::Request`, without any references. /// Customized version of `ruma_client_api::r0::keys::get_keys::Request`, without any references.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct KeysQueryRequest { pub struct KeysQueryRequest {
@ -141,6 +158,9 @@ pub enum IncomingResponse<'a> {
/// The key claiming requests, giving us new one-time keys of other users so /// The key claiming requests, giving us new one-time keys of other users so
/// new Olm sessions can be created. /// new Olm sessions can be created.
KeysClaim(&'a KeysClaimResponse), KeysClaim(&'a KeysClaimResponse),
/// The cross signing keys upload response, marking our private cross
/// signing identity as shared.
SigningKeysUpload(&'a SigningKeysUploadResponse),
} }
impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> { impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> {

View File

@ -1,187 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::BTreeMap, time::Duration};
use matrix_sdk_common::{
api::r0::keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
assign,
identifiers::{DeviceKeyAlgorithm, UserId},
uuid::Uuid,
};
use tracing::{error, info, warn};
use crate::{error::OlmResult, key_request::KeyRequestMachine, olm::Account, store::Store};
#[derive(Debug, Clone)]
pub(crate) struct SessionManager {
account: Account,
store: Store,
key_request_machine: KeyRequestMachine,
}
impl SessionManager {
const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
pub fn new(account: Account, key_request_machine: KeyRequestMachine, store: Store) -> Self {
Self {
account,
store,
key_request_machine,
}
}
/// Get the a key claiming request for the user/device pairs that we are
/// missing Olm sessions for.
///
/// Returns None if no key claiming request needs to be sent out.
///
/// Sessions need to be established between devices so group sessions for a
/// room can be shared with them.
///
/// This should be called every time a group session needs to be shared as
/// well as between sync calls. After a sync some devices may request room
/// keys without us having a valid Olm session with them, making it
/// impossible to server the room key request, thus it's necessary to check
/// for missing sessions between sync as well.
///
/// **Note**: Care should be taken that only one such request at a time is
/// in flight, e.g. using a lock.
///
/// The response of a successful key claiming requests needs to be passed to
/// the `OlmMachine` with the [`receive_keys_claim_response`].
///
/// # Arguments
///
/// `users` - The list of users that we should check if we lack a session
/// with one of their devices. This can be an empty iterator when calling
/// this method between sync requests.
///
/// [`receive_keys_claim_response`]: #method.receive_keys_claim_response
pub async fn get_missing_sessions(
&self,
users: &mut impl Iterator<Item = &UserId>,
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
let mut missing = BTreeMap::new();
// Add the list of devices that the user wishes to establish sessions
// right now.
for user_id in users {
let user_devices = self.store.get_user_devices(user_id).await?;
for device in user_devices.devices() {
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
k
} else {
continue;
};
let sessions = self.store.get_sessions(sender_key).await?;
let is_missing = if let Some(sessions) = sessions {
sessions.lock().await.is_empty()
} else {
true
};
if is_missing {
missing
.entry(user_id.to_owned())
.or_insert_with(BTreeMap::new)
.insert(
device.device_id().into(),
DeviceKeyAlgorithm::SignedCurve25519,
);
}
}
}
// Add the list of sessions that for some reason automatically need to
// create an Olm session.
for item in self.key_request_machine.users_for_key_claim().iter() {
let user = item.key();
for device_id in item.value().iter() {
missing
.entry(user.to_owned())
.or_insert_with(BTreeMap::new)
.insert(device_id.to_owned(), DeviceKeyAlgorithm::SignedCurve25519);
}
}
if missing.is_empty() {
Ok(None)
} else {
Ok(Some((
Uuid::new_v4(),
assign!(KeysClaimRequest::new(missing), {
timeout: Some(Self::KEY_CLAIM_TIMEOUT),
}),
)))
}
}
/// Receive a successful key claim response and create new Olm sessions with
/// the claimed keys.
///
/// # Arguments
///
/// * `response` - The response containing the claimed one-time keys.
pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
// TODO log the failures here
for (user_id, user_devices) in &response.one_time_keys {
for (device_id, key_map) in user_devices {
let device = match self.store.get_readonly_device(&user_id, device_id).await {
Ok(Some(d)) => d,
Ok(None) => {
warn!(
"Tried to create an Olm session for {} {}, but the device is unknown",
user_id, device_id
);
continue;
}
Err(e) => {
warn!(
"Tried to create an Olm session for {} {}, but \
can't fetch the device from the store {:?}",
user_id, device_id, e
);
continue;
}
};
info!("Creating outbound Session for {} {}", user_id, device_id);
let session = match self.account.create_outbound_session(device, &key_map).await {
Ok(s) => s,
Err(e) => {
warn!("{:?}", e);
continue;
}
};
if let Err(e) = self.store.save_sessions(&[session]).await {
error!("Failed to store newly created Olm session {}", e);
continue;
}
// TODO if this session was created because a previous one was
// wedged queue up a dummy event to be sent out.
self.key_request_machine.retry_keyshare(&user_id, device_id);
}
}
Ok(())
}
}

View File

@ -28,8 +28,8 @@ use tracing::{debug, info};
use crate::{ use crate::{
error::{EventError, MegolmResult, OlmResult}, error::{EventError, MegolmResult, OlmResult},
olm::{Account, OutboundGroupSession}, olm::{Account, InboundGroupSession, OutboundGroupSession},
store::Store, store::{Changes, Store},
Device, EncryptionSettings, OlmError, ToDeviceRequest, Device, EncryptionSettings, OlmError, ToDeviceRequest,
}; };
@ -140,19 +140,17 @@ impl GroupSessionManager {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
) -> OlmResult<()> { ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
let (outbound, inbound) = self let (outbound, inbound) = self
.account .account
.create_group_session_pair(room_id, settings) .create_group_session_pair(room_id, settings)
.await .await
.map_err(|_| EventError::UnsupportedAlgorithm)?; .map_err(|_| EventError::UnsupportedAlgorithm)?;
let _ = self.store.save_inbound_group_sessions(&[inbound]).await?;
let _ = self let _ = self
.outbound_group_sessions .outbound_group_sessions
.insert(room_id.to_owned(), outbound); .insert(room_id.to_owned(), outbound.clone());
Ok(()) Ok((outbound, inbound))
} }
/// Get to-device requests to share a group session with users in a room. /// Get to-device requests to share a group session with users in a room.
@ -169,13 +167,12 @@ impl GroupSessionManager {
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>, encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> { ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.create_outbound_group_session(room_id, encryption_settings.into()) let mut changes = Changes::default();
.await?;
let session = self.outbound_group_sessions.get(room_id).unwrap();
if session.shared() { let (session, inbound_session) = self
panic!("Session is already shared"); .create_outbound_group_session(room_id, encryption_settings.into())
} .await?;
changes.inbound_group_sessions.push(inbound_session);
let mut devices: Vec<Device> = Vec::new(); let mut devices: Vec<Device> = Vec::new();
@ -196,7 +193,7 @@ impl GroupSessionManager {
.encrypt(EventType::RoomKey, key_content.clone()) .encrypt(EventType::RoomKey, key_content.clone())
.await; .await;
let encrypted = match encrypted { let (used_session, encrypted) = match encrypted {
Ok(c) => c, Ok(c) => c,
Err(OlmError::MissingSession) Err(OlmError::MissingSession)
| Err(OlmError::EventError(EventError::MissingSenderKey)) => { | Err(OlmError::EventError(EventError::MissingSenderKey)) => {
@ -205,6 +202,8 @@ impl GroupSessionManager {
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
changes.sessions.push(used_session);
messages messages
.entry(device.user_id().clone()) .entry(device.user_id().clone())
.or_insert_with(BTreeMap::new) .or_insert_with(BTreeMap::new)
@ -237,6 +236,8 @@ impl GroupSessionManager {
session.mark_as_shared(); session.mark_as_shared();
} }
self.store.save_changes(changes).await?;
Ok(requests) Ok(requests)
} }
} }

View File

@ -0,0 +1,19 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod group_sessions;
mod sessions;
pub(crate) use group_sessions::GroupSessionManager;
pub(crate) use sessions::SessionManager;

View File

@ -0,0 +1,498 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use dashmap::{DashMap, DashSet};
use matrix_sdk_common::{
api::r0::{
keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
to_device::DeviceIdOrAllDevices,
},
assign,
events::EventType,
identifiers::{DeviceId, DeviceIdBox, DeviceKeyAlgorithm, UserId},
uuid::Uuid,
};
use serde_json::{json, value::to_raw_value};
use tracing::{error, info, warn};
use crate::{
error::OlmResult,
key_request::KeyRequestMachine,
olm::Account,
requests::{OutgoingRequest, ToDeviceRequest},
store::{Changes, Result as StoreResult, Store},
ReadOnlyDevice,
};
#[derive(Debug, Clone)]
pub(crate) struct SessionManager {
account: Account,
store: Store,
/// A map of user/devices that we need to automatically claim keys for.
/// Submodules can insert user/device pairs into this map and the
/// user/device paris will be added to the list of users when
/// [`get_missing_sessions`](#method.get_missing_sessions) is called.
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
wedged_devices: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
key_request_machine: KeyRequestMachine,
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
}
impl SessionManager {
const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60);
pub fn new(
account: Account,
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
key_request_machine: KeyRequestMachine,
store: Store,
) -> Self {
Self {
account,
store,
key_request_machine,
users_for_key_claim,
wedged_devices: Arc::new(DashMap::new()),
outgoing_to_device_requests: Arc::new(DashMap::new()),
}
}
/// Mark the outgoing request as sent.
pub fn mark_outgoing_request_as_sent(&self, id: &Uuid) {
self.outgoing_to_device_requests.remove(id);
}
pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> {
if let Some(device) = self
.store
.get_device_from_curve_key(sender, curve_key)
.await?
{
let sessions = device.get_sessions().await?;
if let Some(sessions) = sessions {
let mut sessions = sessions.lock().await;
sessions.sort_by_key(|s| s.creation_time.clone());
let session = sessions.get(0);
if let Some(session) = session {
if session.creation_time.elapsed() > Self::UNWEDGING_INTERVAL {
self.users_for_key_claim
.entry(device.user_id().clone())
.or_insert_with(DashSet::new)
.insert(device.device_id().into());
self.wedged_devices
.entry(device.user_id().to_owned())
.or_insert_with(DashSet::new)
.insert(device.device_id().into());
}
}
}
}
Ok(())
}
#[allow(dead_code)]
pub fn is_device_wedged(&self, device: &ReadOnlyDevice) -> bool {
self.wedged_devices
.get(device.user_id())
.map(|d| d.contains(device.device_id()))
.unwrap_or(false)
}
/// Check if the session was created to unwedge a Device.
///
/// If the device was wedged this will queue up a dummy to-device message.
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
if self
.wedged_devices
.get(user_id)
.map(|d| d.remove(device_id))
.flatten()
.is_some()
{
if let Some(device) = self.store.get_device(user_id, device_id).await? {
let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
messages
.entry(device.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
let request = OutgoingRequest {
request_id: id,
request: Arc::new(
ToDeviceRequest {
event_type: EventType::RoomEncrypted,
txn_id: id,
messages,
}
.into(),
),
};
self.outgoing_to_device_requests.insert(id, request);
}
}
Ok(())
}
/// Get the a key claiming request for the user/device pairs that we are
/// missing Olm sessions for.
///
/// Returns None if no key claiming request needs to be sent out.
///
/// Sessions need to be established between devices so group sessions for a
/// room can be shared with them.
///
/// This should be called every time a group session needs to be shared as
/// well as between sync calls. After a sync some devices may request room
/// keys without us having a valid Olm session with them, making it
/// impossible to server the room key request, thus it's necessary to check
/// for missing sessions between sync as well.
///
/// **Note**: Care should be taken that only one such request at a time is
/// in flight, e.g. using a lock.
///
/// The response of a successful key claiming requests needs to be passed to
/// the `OlmMachine` with the [`receive_keys_claim_response`].
///
/// # Arguments
///
/// `users` - The list of users that we should check if we lack a session
/// with one of their devices. This can be an empty iterator when calling
/// this method between sync requests.
///
/// [`receive_keys_claim_response`]: #method.receive_keys_claim_response
pub async fn get_missing_sessions(
&self,
users: &mut impl Iterator<Item = &UserId>,
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
let mut missing = BTreeMap::new();
// Add the list of devices that the user wishes to establish sessions
// right now.
for user_id in users {
let user_devices = self.store.get_user_devices(user_id).await?;
for device in user_devices.devices() {
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
k
} else {
continue;
};
let sessions = self.store.get_sessions(sender_key).await?;
let is_missing = if let Some(sessions) = sessions {
sessions.lock().await.is_empty()
} else {
true
};
if is_missing {
missing
.entry(user_id.to_owned())
.or_insert_with(BTreeMap::new)
.insert(
device.device_id().into(),
DeviceKeyAlgorithm::SignedCurve25519,
);
}
}
}
// Add the list of sessions that for some reason automatically need to
// create an Olm session.
for item in self.users_for_key_claim.iter() {
let user = item.key();
for device_id in item.value().iter() {
missing
.entry(user.to_owned())
.or_insert_with(BTreeMap::new)
.insert(device_id.to_owned(), DeviceKeyAlgorithm::SignedCurve25519);
}
}
if missing.is_empty() {
Ok(None)
} else {
Ok(Some((
Uuid::new_v4(),
assign!(KeysClaimRequest::new(missing), {
timeout: Some(Self::KEY_CLAIM_TIMEOUT),
}),
)))
}
}
/// Receive a successful key claim response and create new Olm sessions with
/// the claimed keys.
///
/// # Arguments
///
/// * `response` - The response containing the claimed one-time keys.
pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
// TODO log the failures here
let mut changes = Changes::default();
for (user_id, user_devices) in &response.one_time_keys {
for (device_id, key_map) in user_devices {
let device = match self.store.get_readonly_device(&user_id, device_id).await {
Ok(Some(d)) => d,
Ok(None) => {
warn!(
"Tried to create an Olm session for {} {}, but the device is unknown",
user_id, device_id
);
continue;
}
Err(e) => {
warn!(
"Tried to create an Olm session for {} {}, but \
can't fetch the device from the store {:?}",
user_id, device_id, e
);
continue;
}
};
info!("Creating outbound Session for {} {}", user_id, device_id);
let session = match self.account.create_outbound_session(device, &key_map).await {
Ok(s) => s,
Err(e) => {
warn!("Error creating new outbound session {:?}", e);
continue;
}
};
changes.sessions.push(session);
self.key_request_machine.retry_keyshare(&user_id, device_id);
if let Err(e) = self.check_if_unwedged(&user_id, device_id).await {
error!(
"Error while treating an unwedged device {} {} {:?}",
user_id, device_id, e
);
}
}
}
Ok(self.store.save_changes(changes).await?)
}
}
#[cfg(test)]
mod test {
use dashmap::DashMap;
use matrix_sdk_common::locks::Mutex;
use std::{collections::BTreeMap, sync::Arc};
use matrix_sdk_common::{
api::r0::keys::claim_keys::Response as KeyClaimResponse,
identifiers::{user_id, DeviceIdBox, DeviceKeyAlgorithm, UserId},
instant::{Duration, Instant},
};
use matrix_sdk_test::async_test;
use super::SessionManager;
use crate::{
identities::ReadOnlyDevice,
key_request::KeyRequestMachine,
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
store::{CryptoStore, MemoryStore, Store},
verification::VerificationMachine,
};
fn user_id() -> UserId {
user_id!("@example:localhost")
}
fn device_id() -> DeviceIdBox {
"DEVICEID".into()
}
fn bob_account() -> ReadOnlyAccount {
ReadOnlyAccount::new(&user_id!("@bob:localhost"), "BOBDEVICE".into())
}
async fn session_manager() -> SessionManager {
let user_id = user_id();
let device_id = device_id();
let outbound_sessions = Arc::new(DashMap::new());
let users_for_key_claim = Arc::new(DashMap::new());
let account = ReadOnlyAccount::new(&user_id, &device_id);
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
store.save_account(account.clone()).await.unwrap();
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(
user_id.clone(),
)));
let verification = VerificationMachine::new(account.clone(), identity, store.clone());
let user_id = Arc::new(user_id);
let device_id = Arc::new(device_id);
let store = Store::new(user_id.clone(), store, verification);
let account = Account {
inner: account,
store: store.clone(),
};
let key_request = KeyRequestMachine::new(
user_id,
device_id,
store.clone(),
outbound_sessions,
users_for_key_claim.clone(),
);
SessionManager::new(account, users_for_key_claim, key_request, store)
}
#[async_test]
async fn session_creation() {
let manager = session_manager().await;
let bob = bob_account();
let bob_device = ReadOnlyDevice::from_account(&bob).await;
manager.store.save_devices(&[bob_device]).await.unwrap();
let (_, request) = manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter())
.await
.unwrap()
.unwrap();
assert!(request.one_time_keys.contains_key(bob.user_id()));
bob.generate_one_time_keys_helper(1).await;
let one_time = bob.signed_one_time_keys_helper().await.unwrap();
bob.mark_keys_as_published().await;
let mut one_time_keys = BTreeMap::new();
one_time_keys
.entry(bob.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(bob.device_id().into(), one_time);
let response = KeyClaimResponse {
failures: BTreeMap::new(),
one_time_keys,
};
manager
.receive_keys_claim_response(&response)
.await
.unwrap();
assert!(manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter())
.await
.unwrap()
.is_none());
}
// This test doesn't run on macos because we're modifying the session
// creation time so we can get around the UNWEDGING_INTERVAL.
#[async_test]
#[cfg(not(target_os = "macos"))]
async fn session_unwedging() {
let manager = session_manager().await;
let bob = bob_account();
let (_, mut session) = bob.create_session_for(&manager.account).await;
let bob_device = ReadOnlyDevice::from_account(&bob).await;
session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601));
manager
.store
.save_devices(&[bob_device.clone()])
.await
.unwrap();
manager.store.save_sessions(&[session]).await.unwrap();
assert!(manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter())
.await
.unwrap()
.is_none());
let curve_key = bob_device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap();
assert!(!manager.users_for_key_claim.contains_key(bob.user_id()));
assert!(!manager.is_device_wedged(&bob_device));
manager
.mark_device_as_wedged(bob_device.user_id(), &curve_key)
.await
.unwrap();
assert!(manager.is_device_wedged(&bob_device));
assert!(manager.users_for_key_claim.contains_key(bob.user_id()));
let (_, request) = manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter())
.await
.unwrap()
.unwrap();
assert!(request.one_time_keys.contains_key(bob.user_id()));
bob.generate_one_time_keys_helper(1).await;
let one_time = bob.signed_one_time_keys_helper().await.unwrap();
bob.mark_keys_as_published().await;
let mut one_time_keys = BTreeMap::new();
one_time_keys
.entry(bob.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(bob.device_id().into(), one_time);
let response = KeyClaimResponse {
failures: BTreeMap::new(),
one_time_keys,
};
assert!(manager.outgoing_to_device_requests.is_empty());
manager
.receive_keys_claim_response(&response)
.await
.unwrap();
assert!(!manager.is_device_wedged(&bob_device));
assert!(manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter())
.await
.unwrap()
.is_none());
assert!(!manager.outgoing_to_device_requests.is_empty())
}
}

View File

@ -19,9 +19,9 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use dashmap::{DashMap, ReadOnlyView}; use dashmap::DashMap;
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{DeviceId, RoomId, UserId}, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
locks::Mutex, locks::Mutex,
}; };
@ -145,29 +145,6 @@ pub struct DeviceStore {
entries: Arc<DashMap<UserId, DashMap<Box<DeviceId>, ReadOnlyDevice>>>, entries: Arc<DashMap<UserId, DashMap<Box<DeviceId>, ReadOnlyDevice>>>,
} }
/// A read only view over all devices belonging to a user.
#[derive(Debug)]
pub struct ReadOnlyUserDevices {
entries: ReadOnlyView<Box<DeviceId>, ReadOnlyDevice>,
}
impl ReadOnlyUserDevices {
/// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries.get(device_id).cloned()
}
/// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> {
self.entries.keys().map(|id| id.as_ref())
}
/// Iterator over all the devices of the user devices.
pub fn devices(&self) -> impl Iterator<Item = &ReadOnlyDevice> {
self.entries.values()
}
}
impl DeviceStore { impl DeviceStore {
/// Create a new empty device store. /// Create a new empty device store.
pub fn new() -> Self { pub fn new() -> Self {
@ -206,15 +183,13 @@ impl DeviceStore {
} }
/// Get a read-only view over all devices of the given user. /// Get a read-only view over all devices of the given user.
pub fn user_devices(&self, user_id: &UserId) -> ReadOnlyUserDevices { pub fn user_devices(&self, user_id: &UserId) -> HashMap<DeviceIdBox, ReadOnlyDevice> {
ReadOnlyUserDevices { self.entries
entries: self
.entries
.entry(user_id.clone()) .entry(user_id.clone())
.or_insert_with(DashMap::new) .or_insert_with(DashMap::new)
.clone() .iter()
.into_read_only(), .map(|i| (i.key().to_owned(), i.value().clone()))
} .collect()
} }
} }
@ -305,12 +280,12 @@ mod test {
let user_devices = store.user_devices(device.user_id()); let user_devices = store.user_devices(device.user_id());
assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().next().unwrap(), &device); assert_eq!(user_devices.values().next().unwrap(), &device);
let loaded_device = user_devices.get(device.device_id()).unwrap(); let loaded_device = user_devices.get(device.device_id()).unwrap();
assert_eq!(device, loaded_device); assert_eq!(&device, loaded_device);
store.remove(device.user_id(), device.device_id()); store.remove(device.user_id(), device.device_id());

View File

@ -12,20 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashSet, sync::Arc}; use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{DeviceId, RoomId, UserId}, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
locks::Mutex, locks::Mutex,
}; };
use matrix_sdk_common_macros::async_trait; use matrix_sdk_common_macros::async_trait;
use super::{ use super::{
caches::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore}, caches::{DeviceStore, GroupSessionStore, SessionStore},
CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
};
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
olm::PrivateCrossSigningIdentity,
}; };
use crate::identities::{ReadOnlyDevice, UserIdentities};
/// An in-memory only store that will forget all the E2EE key once it's dropped. /// An in-memory only store that will forget all the E2EE key once it's dropped.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -58,6 +64,30 @@ impl MemoryStore {
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }
pub(crate) async fn save_devices(&self, mut devices: Vec<ReadOnlyDevice>) {
for device in devices.drain(..) {
let _ = self.devices.add(device);
}
}
async fn delete_devices(&self, mut devices: Vec<ReadOnlyDevice>) {
for device in devices.drain(..) {
let _ = self.devices.remove(device.user_id(), device.device_id());
}
}
async fn save_sessions(&self, mut sessions: Vec<Session>) {
for session in sessions.drain(..) {
let _ = self.sessions.add(session.clone()).await;
}
}
async fn save_inbound_group_sessions(&self, mut sessions: Vec<InboundGroupSession>) {
for session in sessions.drain(..) {
self.inbound_group_sessions.add(session);
}
}
} }
#[async_trait] #[async_trait]
@ -70,9 +100,24 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { async fn save_changes(&self, mut changes: Changes) -> Result<()> {
for session in sessions { self.save_sessions(changes.sessions).await;
let _ = self.sessions.add(session.clone()).await; self.save_inbound_group_sessions(changes.inbound_group_sessions)
.await;
self.save_devices(changes.devices.new).await;
self.save_devices(changes.devices.changed).await;
self.delete_devices(changes.devices.deleted).await;
for identity in changes
.identities
.new
.drain(..)
.chain(changes.identities.changed)
{
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
} }
Ok(()) Ok(())
@ -82,14 +127,6 @@ impl CryptoStore for MemoryStore {
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
for session in sessions {
self.inbound_group_sessions.add(session.clone());
}
Ok(())
}
async fn get_inbound_group_session( async fn get_inbound_group_session(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
@ -148,37 +185,18 @@ impl CryptoStore for MemoryStore {
Ok(self.devices.get(user_id, device_id)) Ok(self.devices.get(user_id, device_id))
} }
async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> { async fn get_user_devices(
let _ = self.devices.remove(device.user_id(), device.device_id()); &self,
Ok(()) user_id: &UserId,
} ) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
async fn get_user_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices> {
Ok(self.devices.user_devices(user_id)) Ok(self.devices.user_devices(user_id))
} }
async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
for device in devices {
let _ = self.devices.add(device.clone());
}
Ok(())
}
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> { async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self.identities.get(user_id).map(|i| i.clone())) Ok(self.identities.get(user_id).map(|i| i.clone()))
} }
async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()> {
for identity in identities {
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
}
Ok(())
}
async fn save_value(&self, key: String, value: String) -> Result<()> { async fn save_value(&self, key: String, value: String) -> Result<()> {
self.values.insert(key, value); self.values.insert(key, value);
Ok(()) Ok(())
@ -192,6 +210,14 @@ impl CryptoStore for MemoryStore {
async fn get_value(&self, key: &str) -> Result<Option<String>> { async fn get_value(&self, key: &str) -> Result<Option<String>> {
Ok(self.values.get(key).map(|v| v.to_owned())) Ok(self.values.get(key).map(|v| v.to_owned()))
} }
async fn save_identity(&self, _: PrivateCrossSigningIdentity) -> Result<()> {
Ok(())
}
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
Ok(None)
}
} }
#[cfg(test)] #[cfg(test)]
@ -211,7 +237,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none()); assert!(store.load_account().await.unwrap().is_none());
store.save_account(account).await.unwrap(); store.save_account(account).await.unwrap();
store.save_sessions(&[session.clone()]).await.unwrap(); store.save_sessions(vec![session.clone()]).await;
let sessions = store let sessions = store
.get_sessions(&session.sender_key) .get_sessions(&session.sender_key)
@ -244,9 +270,8 @@ mod test {
let store = MemoryStore::new(); let store = MemoryStore::new();
let _ = store let _ = store
.save_inbound_group_sessions(&[inbound.clone()]) .save_inbound_group_sessions(vec![inbound.clone()])
.await .await;
.unwrap();
let loaded_session = store let loaded_session = store
.get_inbound_group_session(&room_id, "test_key", outbound.session_id()) .get_inbound_group_session(&room_id, "test_key", outbound.session_id())
@ -261,7 +286,7 @@ mod test {
let device = get_device(); let device = get_device();
let store = MemoryStore::new(); let store = MemoryStore::new();
store.save_devices(&[device.clone()]).await.unwrap(); store.save_devices(vec![device.clone()]).await;
let loaded_device = store let loaded_device = store
.get_device(device.user_id(), device.device_id()) .get_device(device.user_id(), device.device_id())
@ -273,14 +298,14 @@ mod test {
let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); let user_devices = store.get_user_devices(device.user_id()).await.unwrap();
assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().next().unwrap(), &device); assert_eq!(user_devices.values().next().unwrap(), &device);
let loaded_device = user_devices.get(device.device_id()).unwrap(); let loaded_device = user_devices.get(device.device_id()).unwrap();
assert_eq!(device, loaded_device); assert_eq!(&device, loaded_device);
store.delete_device(device.clone()).await.unwrap(); store.delete_devices(vec![device.clone()]).await;
assert!(store assert!(store
.get_device(device.user_id(), device.device_id()) .get_device(device.user_id(), device.device_id())
.await .await

View File

@ -39,23 +39,30 @@
pub mod caches; pub mod caches;
mod memorystore; mod memorystore;
mod pickle_key;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
pub(crate) mod sqlite; pub(crate) mod sqlite;
use caches::ReadOnlyUserDevices; use matrix_sdk_common::identifiers::DeviceIdBox;
pub use memorystore::MemoryStore; pub use memorystore::MemoryStore;
pub use pickle_key::{EncryptedPickleKey, PickleKey};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
pub use sqlite::SqliteStore; pub use sqlite::SqliteStore;
use std::{collections::HashSet, fmt::Debug, io::Error as IoError, ops::Deref, sync::Arc}; use std::{
collections::{HashMap, HashSet},
fmt::Debug,
io::Error as IoError,
ops::Deref,
sync::Arc,
};
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use thiserror::Error; use thiserror::Error;
use url::ParseError;
#[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] #[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))]
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -63,7 +70,9 @@ use url::ParseError;
use sqlx::Error as SqlxError; use sqlx::Error as SqlxError;
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{DeviceId, Error as IdentifierValidationError, RoomId, UserId}, identifiers::{
DeviceId, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId, UserId,
},
locks::Mutex, locks::Mutex,
}; };
use matrix_sdk_common_macros::async_trait; use matrix_sdk_common_macros::async_trait;
@ -73,7 +82,7 @@ use matrix_sdk_common_macros::send_sync;
use crate::{ use crate::{
error::SessionUnpicklingError, error::SessionUnpicklingError,
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities},
olm::{InboundGroupSession, ReadOnlyAccount, Session}, olm::{InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session},
verification::VerificationMachine, verification::VerificationMachine,
}; };
@ -93,6 +102,31 @@ pub(crate) struct Store {
verification_machine: VerificationMachine, verification_machine: VerificationMachine,
} }
#[derive(Debug, Default)]
#[allow(missing_docs)]
pub struct Changes {
pub account: Option<ReadOnlyAccount>,
pub sessions: Vec<Session>,
pub inbound_group_sessions: Vec<InboundGroupSession>,
pub identities: IdentityChanges,
pub devices: DeviceChanges,
}
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct IdentityChanges {
pub new: Vec<UserIdentities>,
pub changed: Vec<UserIdentities>,
}
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct DeviceChanges {
pub new: Vec<ReadOnlyDevice>,
pub changed: Vec<ReadOnlyDevice>,
pub deleted: Vec<ReadOnlyDevice>,
}
impl Store { impl Store {
pub fn new( pub fn new(
user_id: Arc<UserId>, user_id: Arc<UserId>,
@ -114,10 +148,61 @@ impl Store {
self.inner.get_device(user_id, device_id).await self.inner.get_device(user_id, device_id).await
} }
pub async fn get_readonly_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices> { pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let changes = Changes {
sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await
}
#[cfg(test)]
pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let changes = Changes {
devices: DeviceChanges {
changed: devices.to_vec(),
..Default::default()
},
..Default::default()
};
self.save_changes(changes).await
}
#[cfg(test)]
pub async fn save_inbound_group_sessions(
&self,
sessions: &[InboundGroupSession],
) -> Result<()> {
let changes = Changes {
inbound_group_sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await
}
pub async fn get_readonly_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
self.inner.get_user_devices(user_id).await self.inner.get_user_devices(user_id).await
} }
pub async fn get_device_from_curve_key(
&self,
user_id: &UserId,
curve_key: &str,
) -> Result<Option<Device>> {
self.get_user_devices(user_id).await.map(|d| {
d.devices().find(|d| {
d.get_key(DeviceKeyAlgorithm::Curve25519)
.map_or(false, |k| k == curve_key)
})
})
}
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> { pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
let devices = self.inner.get_user_devices(user_id).await?; let devices = self.inner.get_user_devices(user_id).await?;
@ -223,6 +308,10 @@ pub enum CryptoStoreError {
#[error(transparent)] #[error(transparent)]
SessionUnpickling(#[from] SessionUnpicklingError), SessionUnpickling(#[from] SessionUnpicklingError),
/// Failed to decrypt an pickled object.
#[error("An object failed to be decrypted while unpickling")]
UnpicklingError,
/// A Matirx identifier failed to be validated. /// A Matirx identifier failed to be validated.
#[error(transparent)] #[error(transparent)]
IdentifierValidation(#[from] IdentifierValidationError), IdentifierValidation(#[from] IdentifierValidationError),
@ -230,10 +319,6 @@ pub enum CryptoStoreError {
/// The store failed to (de)serialize a data type. /// The store failed to (de)serialize a data type.
#[error(transparent)] #[error(transparent)]
Serialization(#[from] SerdeError), Serialization(#[from] SerdeError),
/// An error occurred while parsing an URL.
#[error(transparent)]
UrlParse(#[from] ParseError),
} }
/// Trait abstracting a store that the `OlmMachine` uses to store cryptographic /// Trait abstracting a store that the `OlmMachine` uses to store cryptographic
@ -252,12 +337,14 @@ pub trait CryptoStore: Debug {
/// * `account` - The account that should be stored. /// * `account` - The account that should be stored.
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>;
/// Save the given sessions in the store. /// TODO
/// async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()>;
/// # Arguments
/// /// TODO
/// * `session` - The sessions that should be stored. async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>>;
async fn save_sessions(&self, session: &[Session]) -> Result<()>;
/// TODO
async fn save_changes(&self, changes: Changes) -> Result<()>;
/// Get all the sessions that belong to the given sender key. /// Get all the sessions that belong to the given sender key.
/// ///
@ -266,13 +353,6 @@ pub trait CryptoStore: Debug {
/// * `sender_key` - The sender key that was used to establish the sessions. /// * `sender_key` - The sender key that was used to establish the sessions.
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>; async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
/// Save the given inbound group sessions in the store.
///
/// # Arguments
///
/// * `sessions` - The sessions that should be stored.
async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>;
/// Get the inbound group session from our store. /// Get the inbound group session from our store.
/// ///
/// # Arguments /// # Arguments
@ -312,20 +392,6 @@ pub trait CryptoStore: Debug {
/// * `dirty` - Should the user be also marked for a key query. /// * `dirty` - Should the user be also marked for a key query.
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>; async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>;
/// Save the given devices in the store.
///
/// # Arguments
///
/// * `device` - The device that should be stored.
async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()>;
/// Delete the given device from the store.
///
/// # Arguments
///
/// * `device` - The device that should be stored.
async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()>;
/// Get the device for the given user with the given device id. /// Get the device for the given user with the given device id.
/// ///
/// # Arguments /// # Arguments
@ -344,14 +410,10 @@ pub trait CryptoStore: Debug {
/// # Arguments /// # Arguments
/// ///
/// * `user_id` - The user for which we should get all the devices. /// * `user_id` - The user for which we should get all the devices.
async fn get_user_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices>; async fn get_user_devices(
&self,
/// Save the given user identities in the store. user_id: &UserId,
/// ) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>>;
/// # Arguments
///
/// * `identities` - The identities that should be saved in the store.
async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()>;
/// Get the user identity that is attached to the given user id. /// Get the user identity that is attached to the given user id.
/// ///

View File

@ -0,0 +1,206 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::TryFrom;
use aes_gcm::{
aead::{generic_array::GenericArray, Aead, NewAead},
Aes256Gcm, Error as DecryptionError,
};
use getrandom::getrandom;
use hmac::Hmac;
use olm_rs::PicklingMode;
use pbkdf2::pbkdf2;
use sha2::Sha256;
use zeroize::{Zeroize, Zeroizing};
use serde::{Deserialize, Serialize};
const KEY_SIZE: usize = 32;
const NONCE_SIZE: usize = 12;
const KDF_SALT_SIZE: usize = 32;
const KDF_ROUNDS: u32 = 10000;
/// Version specific info for the key derivation method that is used.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub enum KdfInfo {
Pbkdf2 {
/// The number of PBKDF rounds that were used when deriving the AES key.
rounds: u32,
},
}
/// Version specific info for encryption method that is used to encrypt our
/// pickle key.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub enum CipherTextInfo {
Aes256Gcm {
/// The nonce that was used to encrypt the ciphertext.
nonce: Vec<u8>,
/// The encrypted pickle key.
ciphertext: Vec<u8>,
},
}
/// An encrypted version of our pickle key, this can be safely stored in a
/// database.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct EncryptedPickleKey {
/// Info about the key derivation method that was used to expand the
/// passphrase into an encryption key.
pub kdf_info: KdfInfo,
/// The ciphertext with it's accompanying additional data that is needed to
/// decrypt the pickle key.
pub ciphertext_info: CipherTextInfo,
/// The salt that was used when the passphrase was expanded into a AES key.
kdf_salt: Vec<u8>,
}
/// A pickle key that will be used to encrypt all the private keys for Olm.
///
/// Olm uses AES256 to encrypt accounts, sessions, inbound group sessions. We
/// also implement our own pickling for the cross-signing types using
/// AES256-GCM so the key sizes match.
#[derive(Debug, Zeroize, PartialEq)]
pub struct PickleKey {
aes256_key: Vec<u8>,
}
impl Default for PickleKey {
fn default() -> Self {
let mut key = vec![0u8; KEY_SIZE];
getrandom(&mut key).expect("Can't generate new pickle key");
Self { aes256_key: key }
}
}
impl TryFrom<Vec<u8>> for PickleKey {
type Error = ();
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
if value.len() != KEY_SIZE {
Err(())
} else {
Ok(Self { aes256_key: value })
}
}
}
impl PickleKey {
/// Generate a new random pickle key.
pub fn new() -> Self {
Default::default()
}
fn expand_key(passphrase: &str, salt: &[u8], rounds: u32) -> Zeroizing<Vec<u8>> {
let mut key = Zeroizing::from(vec![0u8; KEY_SIZE]);
pbkdf2::<Hmac<Sha256>>(passphrase.as_bytes(), &salt, rounds, &mut *key);
key
}
/// Get a `PicklingMode` version of this pickle key.
pub fn pickle_mode(&self) -> PicklingMode {
PicklingMode::Encrypted {
key: self.aes256_key.clone(),
}
}
/// Get the raw AES256 key.
pub fn key(&self) -> &[u8] {
&self.aes256_key
}
/// Encrypt and export our pickle key using the given passphrase.
///
/// # Arguments
///
/// * `passphrase` - The passphrase that should be used to encrypt the
/// pickle key.
pub fn encrypt(&self, passphrase: &str) -> EncryptedPickleKey {
let mut salt = vec![0u8; KDF_SALT_SIZE];
getrandom(&mut salt).expect("Can't generate new random pickle key");
let key = PickleKey::expand_key(passphrase, &salt, KDF_ROUNDS);
let key = GenericArray::from_slice(key.as_ref());
let cipher = Aes256Gcm::new(&key);
let mut nonce = vec![0u8; NONCE_SIZE];
getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key");
let ciphertext = cipher
.encrypt(
&GenericArray::from_slice(nonce.as_ref()),
self.aes256_key.as_slice(),
)
.expect("Can't encrypt pickle key");
EncryptedPickleKey {
kdf_info: KdfInfo::Pbkdf2 { rounds: KDF_ROUNDS },
kdf_salt: salt,
ciphertext_info: CipherTextInfo::Aes256Gcm { nonce, ciphertext },
}
}
/// Restore a pickle key from an encrypted export.
///
/// # Arguments
///
/// * `passphrase` - The passphrase that should be used to encrypt the
/// pickle key.
///
/// * `encrypted` - The exported and encrypted version of the pickle key.
pub fn from_encrypted(
passphrase: &str,
encrypted: EncryptedPickleKey,
) -> Result<Self, DecryptionError> {
let key = match encrypted.kdf_info {
KdfInfo::Pbkdf2 { rounds } => Self::expand_key(passphrase, &encrypted.kdf_salt, rounds),
};
let key = GenericArray::from_slice(key.as_ref());
let decrypted = match encrypted.ciphertext_info {
CipherTextInfo::Aes256Gcm { nonce, ciphertext } => {
let cipher = Aes256Gcm::new(&key);
let nonce = GenericArray::from_slice(&nonce);
cipher.decrypt(nonce, ciphertext.as_ref())?
}
};
Ok(Self {
aes256_key: decrypted,
})
}
}
#[cfg(test)]
mod test {
use super::PickleKey;
#[test]
fn generating() {
PickleKey::new();
}
#[test]
fn encrypting() {
let passphrase = "it's a secret to everybody";
let pickle_key = PickleKey::new();
let encrypted = pickle_key.encrypt(passphrase);
let decrypted = PickleKey::from_encrypted(passphrase, encrypted).unwrap();
assert_eq!(pickle_key, decrypted);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,7 @@ use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use matrix_sdk_common::locks::Mutex;
use tracing::{trace, warn}; use tracing::{trace, warn};
use matrix_sdk_common::{ use matrix_sdk_common::{
@ -26,6 +27,7 @@ use matrix_sdk_common::{
use super::sas::{content_to_request, Sas}; use super::sas::{content_to_request, Sas};
use crate::{ use crate::{
olm::PrivateCrossSigningIdentity,
requests::{OutgoingRequest, ToDeviceRequest}, requests::{OutgoingRequest, ToDeviceRequest},
store::{CryptoStore, CryptoStoreError}, store::{CryptoStore, CryptoStoreError},
ReadOnlyAccount, ReadOnlyDevice, ReadOnlyAccount, ReadOnlyDevice,
@ -34,15 +36,21 @@ use crate::{
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct VerificationMachine { pub struct VerificationMachine {
account: ReadOnlyAccount, account: ReadOnlyAccount,
user_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
pub(crate) store: Arc<Box<dyn CryptoStore>>, pub(crate) store: Arc<Box<dyn CryptoStore>>,
verifications: Arc<DashMap<String, Sas>>, verifications: Arc<DashMap<String, Sas>>,
outgoing_to_device_messages: Arc<DashMap<Uuid, OutgoingRequest>>, outgoing_to_device_messages: Arc<DashMap<Uuid, OutgoingRequest>>,
} }
impl VerificationMachine { impl VerificationMachine {
pub(crate) fn new(account: ReadOnlyAccount, store: Arc<Box<dyn CryptoStore>>) -> Self { pub(crate) fn new(
account: ReadOnlyAccount,
identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
store: Arc<Box<dyn CryptoStore>>,
) -> Self {
Self { Self {
account, account,
user_identity: identity,
store, store,
verifications: Arc::new(DashMap::new()), verifications: Arc::new(DashMap::new()),
outgoing_to_device_messages: Arc::new(DashMap::new()), outgoing_to_device_messages: Arc::new(DashMap::new()),
@ -194,8 +202,7 @@ impl VerificationMachine {
self.receive_event_helper(&s, event); self.receive_event_helper(&s, event);
if s.is_done() { if s.is_done() {
if !s.mark_device_as_verified().await? { if let Some(r) = s.mark_as_done().await? {
if let Some(r) = s.cancel() {
self.outgoing_to_device_messages.insert( self.outgoing_to_device_messages.insert(
r.txn_id, r.txn_id,
OutgoingRequest { OutgoingRequest {
@ -204,9 +211,6 @@ impl VerificationMachine {
}, },
); );
} }
} else {
s.mark_identity_as_verified().await?;
}
} }
}; };
} }
@ -228,10 +232,12 @@ mod test {
use matrix_sdk_common::{ use matrix_sdk_common::{
events::AnyToDeviceEventContent, events::AnyToDeviceEventContent,
identifiers::{DeviceId, UserId}, identifiers::{DeviceId, UserId},
locks::Mutex,
}; };
use super::{Sas, VerificationMachine}; use super::{Sas, VerificationMachine};
use crate::{ use crate::{
olm::PrivateCrossSigningIdentity,
requests::OutgoingRequests, requests::OutgoingRequests,
store::{CryptoStore, MemoryStore}, store::{CryptoStore, MemoryStore},
verification::test::{get_content_from_request, wrap_any_to_device_content}, verification::test::{get_content_from_request, wrap_any_to_device_content},
@ -258,18 +264,17 @@ mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
let store = MemoryStore::new(); let store = MemoryStore::new();
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let bob_store = MemoryStore::new();
let bob_device = ReadOnlyDevice::from_account(&bob).await; let bob_device = ReadOnlyDevice::from_account(&bob).await;
let alice_device = ReadOnlyDevice::from_account(&alice).await; let alice_device = ReadOnlyDevice::from_account(&alice).await;
store.save_devices(&[bob_device]).await.unwrap(); store.save_devices(vec![bob_device]).await;
bob_store bob_store.save_devices(vec![alice_device.clone()]).await;
.save_devices(&[alice_device.clone()])
.await
.unwrap();
let machine = VerificationMachine::new(alice, Arc::new(Box::new(store))); let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(bob_store));
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let machine = VerificationMachine::new(alice, identity, Arc::new(Box::new(store)));
let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store, None); let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store, None);
machine machine
.receive_event(&mut wrap_any_to_device_content( .receive_event(&mut wrap_any_to_device_content(
@ -285,8 +290,9 @@ mod test {
#[test] #[test]
fn create() { fn create() {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let store = MemoryStore::new(); let store = MemoryStore::new();
let _ = VerificationMachine::new(alice, Arc::new(Box::new(store))); let _ = VerificationMachine::new(alice, identity, Arc::new(Box::new(store)));
} }
#[tokio::test] #[tokio::test]

View File

@ -34,7 +34,7 @@ use matrix_sdk_common::{
use crate::{ use crate::{
identities::{LocalTrust, ReadOnlyDevice, UserIdentities}, identities::{LocalTrust, ReadOnlyDevice, UserIdentities},
store::{CryptoStore, CryptoStoreError}, store::{Changes, CryptoStore, CryptoStoreError, DeviceChanges},
ReadOnlyAccount, ToDeviceRequest, ReadOnlyAccount, ToDeviceRequest,
}; };
@ -189,34 +189,64 @@ impl Sas {
(content, guard.is_done()) (content, guard.is_done())
}; };
if done { let cancel = if done {
// TODO move the logic that marks and stores the device into the self.mark_as_done().await?
// else branch and only after the identity was verified as well. We
// dont' want to verify one without the other.
if !self.mark_device_as_verified().await? {
return Ok(self.cancel());
} else { } else {
self.mark_identity_as_verified().await?; None
} };
}
if cancel.is_some() {
Ok(cancel)
} else {
Ok(content.map(|c| { Ok(content.map(|c| {
let content = AnyToDeviceEventContent::KeyVerificationMac(c); let content = AnyToDeviceEventContent::KeyVerificationMac(c);
self.content_to_request(content) self.content_to_request(content)
})) }))
} }
}
pub(crate) async fn mark_identity_as_verified(&self) -> Result<bool, CryptoStoreError> { pub(crate) async fn mark_as_done(&self) -> Result<Option<ToDeviceRequest>, CryptoStoreError> {
if let Some(device) = self.mark_device_as_verified().await? {
let identity = self.mark_identity_as_verified().await?;
let mut changes = Changes {
devices: DeviceChanges {
changed: vec![device],
..Default::default()
},
..Default::default()
};
if let Some(i) = identity {
changes.identities.changed.push(i);
}
self.store.save_changes(changes).await?;
Ok(None)
} else {
Ok(self.cancel())
}
}
pub(crate) async fn mark_identity_as_verified(
&self,
) -> Result<Option<UserIdentities>, CryptoStoreError> {
// If there wasn't an identity available during the verification flow // If there wasn't an identity available during the verification flow
// return early as there's nothing to do. // return early as there's nothing to do.
if self.other_identity.is_none() { if self.other_identity.is_none() {
return Ok(false); return Ok(None);
} }
// TODO signal an error, e.g. when the identity got deleted so we don't
// verify/save the device either.
let identity = self.store.get_user_identity(self.other_user_id()).await?; let identity = self.store.get_user_identity(self.other_user_id()).await?;
if let Some(identity) = identity { if let Some(identity) = identity {
if identity.master_key() == self.other_identity.as_ref().unwrap().master_key() { if self
.other_identity
.as_ref()
.map_or(false, |i| i.master_key() == identity.master_key())
{
if self if self
.verified_identities() .verified_identities()
.map_or(false, |i| i.contains(&identity)) .map_or(false, |i| i.contains(&identity))
@ -228,13 +258,12 @@ impl Sas {
if let UserIdentities::Own(i) = &identity { if let UserIdentities::Own(i) = &identity {
i.mark_as_verified(); i.mark_as_verified();
self.store.save_user_identities(&[identity]).await?;
} }
// TODO if we have the private part of the user signing // TODO if we have the private part of the user signing
// key we should sign and upload a signature for this // key we should sign and upload a signature for this
// identity. // identity.
Ok(true) Ok(Some(identity))
} else { } else {
info!( info!(
"The interactive verification process didn't contain a \ "The interactive verification process didn't contain a \
@ -243,7 +272,7 @@ impl Sas {
self.verified_identities(), self.verified_identities(),
); );
Ok(false) Ok(None)
} }
} else { } else {
warn!( warn!(
@ -252,7 +281,7 @@ impl Sas {
identity.user_id(), identity.user_id(),
); );
Ok(false) Ok(None)
} }
} else { } else {
info!( info!(
@ -260,11 +289,13 @@ impl Sas {
verification was going on.", verification was going on.",
self.other_user_id(), self.other_user_id(),
); );
Ok(false) Ok(None)
} }
} }
pub(crate) async fn mark_device_as_verified(&self) -> Result<bool, CryptoStoreError> { pub(crate) async fn mark_device_as_verified(
&self,
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
let device = self let device = self
.store .store
.get_device(self.other_user_id(), self.other_device_id()) .get_device(self.other_user_id(), self.other_device_id())
@ -283,12 +314,11 @@ impl Sas {
); );
device.set_trust_state(LocalTrust::Verified); device.set_trust_state(LocalTrust::Verified);
self.store.save_devices(&[device]).await?;
// TODO if this is a device from our own user and we have // TODO if this is a device from our own user and we have
// the private part of the self signing key, we should sign // the private part of the self signing key, we should sign
// the device and upload the signature. // the device and upload the signature.
Ok(true) Ok(Some(device))
} else { } else {
info!( info!(
"The interactive verification process didn't contain a \ "The interactive verification process didn't contain a \
@ -297,7 +327,7 @@ impl Sas {
device.device_id() device.device_id()
); );
Ok(false) Ok(None)
} }
} else { } else {
warn!( warn!(
@ -306,7 +336,7 @@ impl Sas {
device.user_id(), device.user_id(),
device.device_id() device.device_id()
); );
Ok(false) Ok(None)
} }
} else { } else {
let device = self.other_device(); let device = self.other_device();
@ -317,7 +347,7 @@ impl Sas {
device.user_id(), device.user_id(),
device.device_id() device.device_id()
); );
Ok(false) Ok(None)
} }
} }
@ -777,12 +807,11 @@ mod test {
let bob_device = ReadOnlyDevice::from_account(&bob).await; let bob_device = ReadOnlyDevice::from_account(&bob).await;
let alice_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let alice_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let bob_store = MemoryStore::new();
bob_store bob_store.save_devices(vec![alice_device.clone()]).await;
.save_devices(&[alice_device.clone()])
.await let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(bob_store));
.unwrap();
let (alice, content) = Sas::start(alice, bob_device, alice_store, None); let (alice, content) = Sas::start(alice, bob_device, alice_store, None);
let event = wrap_to_device_event(alice.user_id(), content); let event = wrap_to_device_event(alice.user_id(), content);

View File

@ -11,7 +11,7 @@ repository = "https://github.com/matrix-org/matrix-rust-sdk"
version = "0.1.0" version = "0.1.0"
[dependencies] [dependencies]
serde_json = "1.0.58" serde_json = "1.0.59"
http = "0.2.1" http = "0.2.1"
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" } matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" }