Merge branch 'master' into room-state-getters

master
Damir Jelić 2021-07-27 11:18:29 +02:00
commit aa5f532f86
82 changed files with 4859 additions and 1989 deletions

2
.cargo/config.toml Normal file
View File

@ -0,0 +1,2 @@
[doc.extern-map.registries]
crates-io = "https://docs.rs/"

View File

@ -94,19 +94,10 @@ jobs:
strategy: strategy:
matrix: matrix:
name: name:
- linux / appservice / stable / actix
- macOS / appservice / stable / actix
- linux / appservice / stable / warp - linux / appservice / stable / warp
- macOS / appservice / stable / warp - macOS / appservice / stable / warp
include: include:
- name: linux / appservice / stable / actix
cargo_args: --no-default-features --features actix
- name: macOS / appservice / stable / actix
os: macOS-latest
cargo_args: --no-default-features --features actix
- name: linux / appservice / stable / warp - name: linux / appservice / stable / warp
cargo_args: --features warp cargo_args: --features warp

View File

@ -27,7 +27,7 @@ jobs:
RUSTDOCFLAGS: "--enable-index-page -Zunstable-options" RUSTDOCFLAGS: "--enable-index-page -Zunstable-options"
with: with:
command: doc command: doc
args: --no-deps --workspace --features docs args: --no-deps --workspace --features docs -Zrustdoc-map
- name: Deploy docs - name: Deploy docs
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}

View File

@ -1,8 +1,14 @@
[package] [package]
name = "matrix-qrcode" name = "matrix-qrcode"
description = "Library to encode and decode QR codes for interactive verifications in Matrix land"
version = "0.1.0" version = "0.1.0"
authors = ["Damir Jelić <poljar@termina.org.uk>"] authors = ["Damir Jelić <poljar@termina.org.uk>"]
edition = "2018" edition = "2018"
homepage = "https://github.com/matrix-org/matrix-rust-sdk"
keywords = ["matrix", "chat", "messaging", "ruma", "nio"]
license = "Apache-2.0"
readme = "README.md"
repository = "https://github.com/matrix-org/matrix-rust-sdk"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["docs"] features = ["docs"]
@ -19,6 +25,6 @@ base64 = "0.13.0"
byteorder = "1.4.3" byteorder = "1.4.3"
image = { version = "0.23.14", optional = true } image = { version = "0.23.14", optional = true }
qrcode = { version = "0.12.0", default-features = false } qrcode = { version = "0.12.0", default-features = false }
rqrr = { version = "0.3.2" , optional = true } rqrr = { version = "0.3.2", optional = true }
ruma-identifiers = "0.19.1" ruma-identifiers = "0.19.3"
thiserror = "1.0.24" thiserror = "1.0.25"

62
matrix_qrcode/README.md Normal file
View File

@ -0,0 +1,62 @@
[![Build Status](https://img.shields.io/travis/matrix-org/matrix-rust-sdk.svg?style=flat-square)](https://travis-ci.org/matrix-org/matrix-rust-sdk)
[![codecov](https://img.shields.io/codecov/c/github/matrix-org/matrix-rust-sdk/master.svg?style=flat-square)](https://codecov.io/gh/matrix-org/matrix-rust-sdk)
[![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg?style=flat-square)](https://opensource.org/licenses/Apache-2.0)
[![#matrix-rust-sdk](https://img.shields.io/badge/matrix-%23matrix--rust--sdk-blue?style=flat-square)](https://matrix.to/#/#matrix-rust-sdk:matrix.org)
# matrix-qrcode
**matrix-qrcode** is a crate to easily generate and parse QR codes for
interactive verification using [QR codes] in Matrix.
[Matrix]: https://matrix.org/
[Rust]: https://www.rust-lang.org/
[QR codes]: https://spec.matrix.org/unstable/client-server-api/#qr-codes
## Usage
This is probably not the crate you are looking for, it's used internally in the
matrix-rust-sdk.
If you still want to play with QR codes, here are a couple of helpful examples.
### Decode an image
```rust
use image;
use matrix_qrcode::{QrVerificationData, DecodingError};
fn main() -> Result<(), DecodingError> {
let image = image::open("/path/to/my/image.png").unwrap();
let result = QrVerificationData::from_image(image)?;
Ok(())
}
```
### Encode into a QR code
```rust
use matrix_qrcode::{QrVerificationData, DecodingError};
use image::Luma;
fn main() -> Result<(), DecodingError> {
let data = b"MATRIX\
\x02\x02\x00\x07\
FLOW_ID\
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
SHARED_SECRET";
let data = QrVerificationData::from_bytes(data)?;
let encoded = data.to_qr_code().unwrap();
let image = encoded.render::<Luma<u8>>().build();
Ok(())
}
```
## License
[Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0)

View File

@ -20,12 +20,12 @@
//! [spec]: https://spec.matrix.org/unstable/client-server-api/#qr-code-format //! [spec]: https://spec.matrix.org/unstable/client-server-api/#qr-code-format
//! //!
//! ```no_run //! ```no_run
//! # use matrix_qrcode::{QrVerification, DecodingError}; //! # use matrix_qrcode::{QrVerificationData, DecodingError};
//! # fn main() -> Result<(), DecodingError> { //! # fn main() -> Result<(), DecodingError> {
//! use image; //! use image;
//! //!
//! let image = image::open("/path/to/my/image.png").unwrap(); //! let image = image::open("/path/to/my/image.png").unwrap();
//! let result = QrVerification::from_image(image)?; //! let result = QrVerificationData::from_image(image)?;
//! # Ok(()) //! # Ok(())
//! # } //! # }
//! ``` //! ```
@ -55,7 +55,7 @@ pub use qrcode;
#[cfg_attr(feature = "docs", doc(cfg(decode_image)))] #[cfg_attr(feature = "docs", doc(cfg(decode_image)))]
pub use rqrr; pub use rqrr;
pub use types::{ pub use types::{
QrVerification, SelfVerificationData, SelfVerificationNoMasterKey, VerificationData, QrVerificationData, SelfVerificationData, SelfVerificationNoMasterKey, VerificationData,
}; };
#[cfg(test)] #[cfg(test)]
@ -70,7 +70,7 @@ mod test {
#[cfg(feature = "decode_image")] #[cfg(feature = "decode_image")]
use crate::utils::decode_qr; use crate::utils::decode_qr;
use crate::{DecodingError, QrVerification}; use crate::{DecodingError, QrVerificationData};
#[cfg(feature = "decode_image")] #[cfg(feature = "decode_image")]
static VERIFICATION: &[u8; 4277] = include_bytes!("../data/verification.png"); static VERIFICATION: &[u8; 4277] = include_bytes!("../data/verification.png");
@ -92,9 +92,9 @@ mod test {
fn decode_test() { fn decode_test() {
let image = Cursor::new(VERIFICATION); let image = Cursor::new(VERIFICATION);
let image = image::load(image, ImageFormat::Png).unwrap().to_luma8(); let image = image::load(image, ImageFormat::Png).unwrap().to_luma8();
let result = QrVerification::try_from(image).unwrap(); let result = QrVerificationData::try_from(image).unwrap();
assert!(matches!(result, QrVerification::Verification(_))); assert!(matches!(result, QrVerificationData::Verification(_)));
} }
#[test] #[test]
@ -102,18 +102,18 @@ mod test {
fn decode_encode_cycle() { fn decode_encode_cycle() {
let image = Cursor::new(VERIFICATION); let image = Cursor::new(VERIFICATION);
let image = image::load(image, ImageFormat::Png).unwrap(); let image = image::load(image, ImageFormat::Png).unwrap();
let result = QrVerification::from_image(image).unwrap(); let result = QrVerificationData::from_image(image).unwrap();
assert!(matches!(result, QrVerification::Verification(_))); assert!(matches!(result, QrVerificationData::Verification(_)));
let encoded = result.to_qr_code().unwrap(); let encoded = result.to_qr_code().unwrap();
let image = encoded.render::<Luma<u8>>().build(); let image = encoded.render::<Luma<u8>>().build();
let second_result = QrVerification::try_from(image).unwrap(); let second_result = QrVerificationData::try_from(image).unwrap();
assert_eq!(result, second_result); assert_eq!(result, second_result);
let bytes = result.to_bytes().unwrap(); let bytes = result.to_bytes().unwrap();
let third_result = QrVerification::from_bytes(bytes).unwrap(); let third_result = QrVerificationData::from_bytes(bytes).unwrap();
assert_eq!(result, third_result); assert_eq!(result, third_result);
} }
@ -123,18 +123,18 @@ mod test {
fn decode_encode_cycle_self() { fn decode_encode_cycle_self() {
let image = Cursor::new(SELF_VERIFICATION); let image = Cursor::new(SELF_VERIFICATION);
let image = image::load(image, ImageFormat::Png).unwrap(); let image = image::load(image, ImageFormat::Png).unwrap();
let result = QrVerification::try_from(image).unwrap(); let result = QrVerificationData::try_from(image).unwrap();
assert!(matches!(result, QrVerification::SelfVerification(_))); assert!(matches!(result, QrVerificationData::SelfVerification(_)));
let encoded = result.to_qr_code().unwrap(); let encoded = result.to_qr_code().unwrap();
let image = encoded.render::<Luma<u8>>().build(); let image = encoded.render::<Luma<u8>>().build();
let second_result = QrVerification::from_luma(image).unwrap(); let second_result = QrVerificationData::from_luma(image).unwrap();
assert_eq!(result, second_result); assert_eq!(result, second_result);
let bytes = result.to_bytes().unwrap(); let bytes = result.to_bytes().unwrap();
let third_result = QrVerification::from_bytes(bytes).unwrap(); let third_result = QrVerificationData::from_bytes(bytes).unwrap();
assert_eq!(result, third_result); assert_eq!(result, third_result);
} }
@ -144,18 +144,18 @@ mod test {
fn decode_encode_cycle_self_no_master() { fn decode_encode_cycle_self_no_master() {
let image = Cursor::new(SELF_NO_MASTER); let image = Cursor::new(SELF_NO_MASTER);
let image = image::load(image, ImageFormat::Png).unwrap(); let image = image::load(image, ImageFormat::Png).unwrap();
let result = QrVerification::from_image(image).unwrap(); let result = QrVerificationData::from_image(image).unwrap();
assert!(matches!(result, QrVerification::SelfVerificationNoMasterKey(_))); assert!(matches!(result, QrVerificationData::SelfVerificationNoMasterKey(_)));
let encoded = result.to_qr_code().unwrap(); let encoded = result.to_qr_code().unwrap();
let image = encoded.render::<Luma<u8>>().build(); let image = encoded.render::<Luma<u8>>().build();
let second_result = QrVerification::try_from(image).unwrap(); let second_result = QrVerificationData::try_from(image).unwrap();
assert_eq!(result, second_result); assert_eq!(result, second_result);
let bytes = result.to_bytes().unwrap(); let bytes = result.to_bytes().unwrap();
let third_result = QrVerification::try_from(bytes).unwrap(); let third_result = QrVerificationData::try_from(bytes).unwrap();
assert_eq!(result, third_result); assert_eq!(result, third_result);
} }
@ -165,35 +165,35 @@ mod test {
fn decode_invalid_qr() { fn decode_invalid_qr() {
let qr = QrCode::new(b"NonMatrixCode").expect("Can't build a simple QR code"); let qr = QrCode::new(b"NonMatrixCode").expect("Can't build a simple QR code");
let image = qr.render::<Luma<u8>>().build(); let image = qr.render::<Luma<u8>>().build();
let result = QrVerification::try_from(image); let result = QrVerificationData::try_from(image);
assert!(matches!(result, Err(DecodingError::Header))) assert!(matches!(result, Err(DecodingError::Header)))
} }
#[test] #[test]
fn decode_invalid_header() { fn decode_invalid_header() {
let data = b"NonMatrixCode"; let data = b"NonMatrixCode";
let result = QrVerification::from_bytes(data); let result = QrVerificationData::from_bytes(data);
assert!(matches!(result, Err(DecodingError::Header))) assert!(matches!(result, Err(DecodingError::Header)))
} }
#[test] #[test]
fn decode_invalid_mode() { fn decode_invalid_mode() {
let data = b"MATRIX\x02\x03"; let data = b"MATRIX\x02\x03";
let result = QrVerification::from_bytes(data); let result = QrVerificationData::from_bytes(data);
assert!(matches!(result, Err(DecodingError::Mode(3)))) assert!(matches!(result, Err(DecodingError::Mode(3))))
} }
#[test] #[test]
fn decode_invalid_version() { fn decode_invalid_version() {
let data = b"MATRIX\x01\x03"; let data = b"MATRIX\x01\x03";
let result = QrVerification::from_bytes(data); let result = QrVerificationData::from_bytes(data);
assert!(matches!(result, Err(DecodingError::Version(1)))) assert!(matches!(result, Err(DecodingError::Version(1))))
} }
#[test] #[test]
fn decode_missing_data() { fn decode_missing_data() {
let data = b"MATRIX\x02\x02"; let data = b"MATRIX\x02\x02";
let result = QrVerification::from_bytes(data); let result = QrVerificationData::from_bytes(data);
assert!(matches!(result, Err(DecodingError::Read(_)))) assert!(matches!(result, Err(DecodingError::Read(_))))
} }
@ -206,7 +206,7 @@ mod test {
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
SECRET"; SECRET";
let result = QrVerification::from_bytes(data); let result = QrVerificationData::from_bytes(data);
assert!(matches!(result, Err(DecodingError::SharedSecret(_)))) assert!(matches!(result, Err(DecodingError::SharedSecret(_))))
} }
@ -219,7 +219,7 @@ mod test {
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
SECRETISLONGENOUGH"; SECRETISLONGENOUGH";
let result = QrVerification::from_bytes(data); let result = QrVerificationData::from_bytes(data);
assert!(matches!(result, Err(DecodingError::Identifier(_)))) assert!(matches!(result, Err(DecodingError::Identifier(_))))
} }
} }

View File

@ -33,7 +33,7 @@ use crate::{
/// An enum representing the different modes a QR verification can be in. /// An enum representing the different modes a QR verification can be in.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum QrVerification { pub enum QrVerificationData {
/// The QR verification is verifying another user /// The QR verification is verifying another user
Verification(VerificationData), Verification(VerificationData),
/// The QR verification is self-verifying and the current device trusts or /// The QR verification is self-verifying and the current device trusts or
@ -46,7 +46,7 @@ pub enum QrVerification {
#[cfg(feature = "decode_image")] #[cfg(feature = "decode_image")]
#[cfg_attr(feature = "docs", doc(cfg(decode_image)))] #[cfg_attr(feature = "docs", doc(cfg(decode_image)))]
impl TryFrom<DynamicImage> for QrVerification { impl TryFrom<DynamicImage> for QrVerificationData {
type Error = DecodingError; type Error = DecodingError;
fn try_from(image: DynamicImage) -> Result<Self, Self::Error> { fn try_from(image: DynamicImage) -> Result<Self, Self::Error> {
@ -56,7 +56,7 @@ impl TryFrom<DynamicImage> for QrVerification {
#[cfg(feature = "decode_image")] #[cfg(feature = "decode_image")]
#[cfg_attr(feature = "docs", doc(cfg(decode_image)))] #[cfg_attr(feature = "docs", doc(cfg(decode_image)))]
impl TryFrom<ImageBuffer<Luma<u8>, Vec<u8>>> for QrVerification { impl TryFrom<ImageBuffer<Luma<u8>, Vec<u8>>> for QrVerificationData {
type Error = DecodingError; type Error = DecodingError;
fn try_from(image: ImageBuffer<Luma<u8>, Vec<u8>>) -> Result<Self, Self::Error> { fn try_from(image: ImageBuffer<Luma<u8>, Vec<u8>>) -> Result<Self, Self::Error> {
@ -64,7 +64,7 @@ impl TryFrom<ImageBuffer<Luma<u8>, Vec<u8>>> for QrVerification {
} }
} }
impl TryFrom<&[u8]> for QrVerification { impl TryFrom<&[u8]> for QrVerificationData {
type Error = DecodingError; type Error = DecodingError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> { fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
@ -72,7 +72,7 @@ impl TryFrom<&[u8]> for QrVerification {
} }
} }
impl TryFrom<Vec<u8>> for QrVerification { impl TryFrom<Vec<u8>> for QrVerificationData {
type Error = DecodingError; type Error = DecodingError;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> { fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
@ -80,8 +80,8 @@ impl TryFrom<Vec<u8>> for QrVerification {
} }
} }
impl QrVerification { impl QrVerificationData {
/// Decode and parse an image of a QR code into a `QrVerification` /// Decode and parse an image of a QR code into a `QrVerificationData`
/// ///
/// The image will be converted into a grey scale image before decoding is /// The image will be converted into a grey scale image before decoding is
/// attempted /// attempted
@ -92,12 +92,12 @@ impl QrVerification {
/// ///
/// # Example /// # Example
/// ```no_run /// ```no_run
/// # use matrix_qrcode::{QrVerification, DecodingError}; /// # use matrix_qrcode::{QrVerificationData, DecodingError};
/// # fn main() -> Result<(), DecodingError> { /// # fn main() -> Result<(), DecodingError> {
/// use image; /// use image;
/// ///
/// let image = image::open("/path/to/my/image.png").unwrap(); /// let image = image::open("/path/to/my/image.png").unwrap();
/// let result = QrVerification::from_image(image)?; /// let result = QrVerificationData::from_image(image)?;
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
@ -109,7 +109,7 @@ impl QrVerification {
} }
/// Decode and parse an grey scale image of a QR code into a /// Decode and parse an grey scale image of a QR code into a
/// `QrVerification` /// `QrVerificationData`
/// ///
/// # Arguments /// # Arguments
/// ///
@ -117,13 +117,13 @@ impl QrVerification {
/// ///
/// # Example /// # Example
/// ```no_run /// ```no_run
/// # use matrix_qrcode::{QrVerification, DecodingError}; /// # use matrix_qrcode::{QrVerificationData, DecodingError};
/// # fn main() -> Result<(), DecodingError> { /// # fn main() -> Result<(), DecodingError> {
/// use image; /// use image;
/// ///
/// let image = image::open("/path/to/my/image.png").unwrap(); /// let image = image::open("/path/to/my/image.png").unwrap();
/// let image = image.to_luma8(); /// let image = image.to_luma8();
/// let result = QrVerification::from_luma(image)?; /// let result = QrVerificationData::from_luma(image)?;
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
@ -134,7 +134,7 @@ impl QrVerification {
} }
/// Parse the decoded payload of a QR code in byte slice form as a /// Parse the decoded payload of a QR code in byte slice form as a
/// `QrVerification` /// `QrVerificationData`
/// ///
/// This method is useful if you would like to do your own custom QR code /// This method is useful if you would like to do your own custom QR code
/// decoding. /// decoding.
@ -145,7 +145,7 @@ impl QrVerification {
/// ///
/// # Example /// # Example
/// ``` /// ```
/// # use matrix_qrcode::{QrVerification, DecodingError}; /// # use matrix_qrcode::{QrVerificationData, DecodingError};
/// # fn main() -> Result<(), DecodingError> { /// # fn main() -> Result<(), DecodingError> {
/// let data = b"MATRIX\ /// let data = b"MATRIX\
/// \x02\x02\x00\x07\ /// \x02\x02\x00\x07\
@ -154,7 +154,7 @@ impl QrVerification {
/// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\ /// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
/// SHARED_SECRET"; /// SHARED_SECRET";
/// ///
/// let result = QrVerification::from_bytes(data)?; /// let result = QrVerificationData::from_bytes(data)?;
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
@ -162,9 +162,9 @@ impl QrVerification {
Self::decode_bytes(bytes) Self::decode_bytes(bytes)
} }
/// Encode the `QrVerification` into a `QrCode`. /// Encode the `QrVerificationData` into a `QrCode`.
/// ///
/// This method turns the `QrVerification` into a QR code that can be /// This method turns the `QrVerificationData` into a QR code that can be
/// rendered and presented to be scanned. /// rendered and presented to be scanned.
/// ///
/// The encoding can fail if the data doesn't fit into a QR code or if the /// The encoding can fail if the data doesn't fit into a QR code or if the
@ -173,7 +173,7 @@ impl QrVerification {
/// ///
/// # Example /// # Example
/// ``` /// ```
/// # use matrix_qrcode::{QrVerification, DecodingError}; /// # use matrix_qrcode::{QrVerificationData, DecodingError};
/// # fn main() -> Result<(), DecodingError> { /// # fn main() -> Result<(), DecodingError> {
/// let data = b"MATRIX\ /// let data = b"MATRIX\
/// \x02\x02\x00\x07\ /// \x02\x02\x00\x07\
@ -182,28 +182,28 @@ impl QrVerification {
/// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\ /// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
/// SHARED_SECRET"; /// SHARED_SECRET";
/// ///
/// let result = QrVerification::from_bytes(data)?; /// let result = QrVerificationData::from_bytes(data)?;
/// let encoded = result.to_qr_code().unwrap(); /// let encoded = result.to_qr_code().unwrap();
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
pub fn to_qr_code(&self) -> Result<QrCode, EncodingError> { pub fn to_qr_code(&self) -> Result<QrCode, EncodingError> {
match self { match self {
QrVerification::Verification(v) => v.to_qr_code(), QrVerificationData::Verification(v) => v.to_qr_code(),
QrVerification::SelfVerification(v) => v.to_qr_code(), QrVerificationData::SelfVerification(v) => v.to_qr_code(),
QrVerification::SelfVerificationNoMasterKey(v) => v.to_qr_code(), QrVerificationData::SelfVerificationNoMasterKey(v) => v.to_qr_code(),
} }
} }
/// Encode the `QrVerification` into a vector of bytes that can be encoded /// Encode the `QrVerificationData` into a vector of bytes that can be
/// as a QR code. /// encoded as a QR code.
/// ///
/// The encoding can fail if the identity keys that should be encoded are /// The encoding can fail if the identity keys that should be encoded are
/// not valid base64. /// not valid base64.
/// ///
/// # Example /// # Example
/// ``` /// ```
/// # use matrix_qrcode::{QrVerification, DecodingError}; /// # use matrix_qrcode::{QrVerificationData, DecodingError};
/// # fn main() -> Result<(), DecodingError> { /// # fn main() -> Result<(), DecodingError> {
/// let data = b"MATRIX\ /// let data = b"MATRIX\
/// \x02\x02\x00\x07\ /// \x02\x02\x00\x07\
@ -212,7 +212,7 @@ impl QrVerification {
/// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\ /// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
/// SHARED_SECRET"; /// SHARED_SECRET";
/// ///
/// let result = QrVerification::from_bytes(data)?; /// let result = QrVerificationData::from_bytes(data)?;
/// let encoded = result.to_bytes().unwrap(); /// let encoded = result.to_bytes().unwrap();
/// ///
/// assert_eq!(data.as_ref(), encoded.as_slice()); /// assert_eq!(data.as_ref(), encoded.as_slice());
@ -221,9 +221,9 @@ impl QrVerification {
/// ``` /// ```
pub fn to_bytes(&self) -> Result<Vec<u8>, EncodingError> { pub fn to_bytes(&self) -> Result<Vec<u8>, EncodingError> {
match self { match self {
QrVerification::Verification(v) => v.to_bytes(), QrVerificationData::Verification(v) => v.to_bytes(),
QrVerification::SelfVerification(v) => v.to_bytes(), QrVerificationData::SelfVerification(v) => v.to_bytes(),
QrVerification::SelfVerificationNoMasterKey(v) => v.to_bytes(), QrVerificationData::SelfVerificationNoMasterKey(v) => v.to_bytes(),
} }
} }
@ -289,13 +289,13 @@ impl QrVerification {
return Err(DecodingError::SharedSecret(shared_secret.len())); return Err(DecodingError::SharedSecret(shared_secret.len()));
} }
QrVerification::new(mode, flow_id, first_key, second_key, shared_secret) QrVerificationData::new(mode, flow_id, first_key, second_key, shared_secret)
} }
/// Decode the given image of an QR code and if we find a valid code, try to /// Decode the given image of an QR code and if we find a valid code, try to
/// decode it as a `QrVerification`. /// decode it as a `QrVerification`.
#[cfg(feature = "decode_image")] #[cfg(feature = "decode_image")]
fn decode(image: ImageBuffer<Luma<u8>, Vec<u8>>) -> Result<QrVerification, DecodingError> { fn decode(image: ImageBuffer<Luma<u8>, Vec<u8>>) -> Result<QrVerificationData, DecodingError> {
let decoded = decode_qr(image)?; let decoded = decode_qr(image)?;
Self::decode_bytes(decoded) Self::decode_bytes(decoded)
} }
@ -328,41 +328,41 @@ impl QrVerification {
} }
} }
/// Get the flow id for this `QrVerification`. /// Get the flow id for this `QrVerificationData`.
/// ///
/// This represents the ID as a string even if it is a `EventId`. /// This represents the ID as a string even if it is a `EventId`.
pub fn flow_id(&self) -> &str { pub fn flow_id(&self) -> &str {
match self { match self {
QrVerification::Verification(v) => v.event_id.as_str(), QrVerificationData::Verification(v) => v.event_id.as_str(),
QrVerification::SelfVerification(v) => &v.transaction_id, QrVerificationData::SelfVerification(v) => &v.transaction_id,
QrVerification::SelfVerificationNoMasterKey(v) => &v.transaction_id, QrVerificationData::SelfVerificationNoMasterKey(v) => &v.transaction_id,
} }
} }
/// Get the first key of this `QrVerification`. /// Get the first key of this `QrVerificationData`.
pub fn first_key(&self) -> &str { pub fn first_key(&self) -> &str {
match self { match self {
QrVerification::Verification(v) => &v.first_master_key, QrVerificationData::Verification(v) => &v.first_master_key,
QrVerification::SelfVerification(v) => &v.master_key, QrVerificationData::SelfVerification(v) => &v.master_key,
QrVerification::SelfVerificationNoMasterKey(v) => &v.device_key, QrVerificationData::SelfVerificationNoMasterKey(v) => &v.device_key,
} }
} }
/// Get the second key of this `QrVerification`. /// Get the second key of this `QrVerificationData`.
pub fn second_key(&self) -> &str { pub fn second_key(&self) -> &str {
match self { match self {
QrVerification::Verification(v) => &v.second_master_key, QrVerificationData::Verification(v) => &v.second_master_key,
QrVerification::SelfVerification(v) => &v.device_key, QrVerificationData::SelfVerification(v) => &v.device_key,
QrVerification::SelfVerificationNoMasterKey(v) => &v.master_key, QrVerificationData::SelfVerificationNoMasterKey(v) => &v.master_key,
} }
} }
/// Get the secret of this `QrVerification`. /// Get the secret of this `QrVerificationData`.
pub fn secret(&self) -> &str { pub fn secret(&self) -> &str {
match self { match self {
QrVerification::Verification(v) => &v.shared_secret, QrVerificationData::Verification(v) => &v.shared_secret,
QrVerification::SelfVerification(v) => &v.shared_secret, QrVerificationData::SelfVerification(v) => &v.shared_secret,
QrVerification::SelfVerificationNoMasterKey(v) => &v.shared_secret, QrVerificationData::SelfVerificationNoMasterKey(v) => &v.shared_secret,
} }
} }
} }
@ -412,7 +412,7 @@ impl VerificationData {
/// ///
/// # Example /// # Example
/// ``` /// ```
/// # use matrix_qrcode::{QrVerification, DecodingError}; /// # use matrix_qrcode::{QrVerificationData, DecodingError};
/// # fn main() -> Result<(), DecodingError> { /// # fn main() -> Result<(), DecodingError> {
/// let data = b"MATRIX\ /// let data = b"MATRIX\
/// \x02\x00\x00\x0f\ /// \x02\x00\x00\x0f\
@ -421,8 +421,8 @@ impl VerificationData {
/// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\ /// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
/// SHARED_SECRET"; /// SHARED_SECRET";
/// ///
/// let result = QrVerification::from_bytes(data)?; /// let result = QrVerificationData::from_bytes(data)?;
/// if let QrVerification::Verification(decoded) = result { /// if let QrVerificationData::Verification(decoded) = result {
/// let encoded = decoded.to_bytes().unwrap(); /// let encoded = decoded.to_bytes().unwrap();
/// assert_eq!(data.as_ref(), encoded.as_slice()); /// assert_eq!(data.as_ref(), encoded.as_slice());
/// } else { /// } else {
@ -459,7 +459,7 @@ impl VerificationData {
} }
} }
impl From<VerificationData> for QrVerification { impl From<VerificationData> for QrVerificationData {
fn from(data: VerificationData) -> Self { fn from(data: VerificationData) -> Self {
Self::Verification(data) Self::Verification(data)
} }
@ -515,7 +515,7 @@ impl SelfVerificationData {
/// ///
/// # Example /// # Example
/// ``` /// ```
/// # use matrix_qrcode::{QrVerification, DecodingError}; /// # use matrix_qrcode::{QrVerificationData, DecodingError};
/// # fn main() -> Result<(), DecodingError> { /// # fn main() -> Result<(), DecodingError> {
/// let data = b"MATRIX\ /// let data = b"MATRIX\
/// \x02\x01\x00\x06\ /// \x02\x01\x00\x06\
@ -524,8 +524,8 @@ impl SelfVerificationData {
/// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\ /// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
/// SHARED_SECRET"; /// SHARED_SECRET";
/// ///
/// let result = QrVerification::from_bytes(data)?; /// let result = QrVerificationData::from_bytes(data)?;
/// if let QrVerification::SelfVerification(decoded) = result { /// if let QrVerificationData::SelfVerification(decoded) = result {
/// let encoded = decoded.to_bytes().unwrap(); /// let encoded = decoded.to_bytes().unwrap();
/// assert_eq!(data.as_ref(), encoded.as_slice()); /// assert_eq!(data.as_ref(), encoded.as_slice());
/// } else { /// } else {
@ -562,7 +562,7 @@ impl SelfVerificationData {
} }
} }
impl From<SelfVerificationData> for QrVerification { impl From<SelfVerificationData> for QrVerificationData {
fn from(data: SelfVerificationData) -> Self { fn from(data: SelfVerificationData) -> Self {
Self::SelfVerification(data) Self::SelfVerification(data)
} }
@ -618,7 +618,7 @@ impl SelfVerificationNoMasterKey {
/// ///
/// # Example /// # Example
/// ``` /// ```
/// # use matrix_qrcode::{QrVerification, DecodingError}; /// # use matrix_qrcode::{QrVerificationData, DecodingError};
/// # fn main() -> Result<(), DecodingError> { /// # fn main() -> Result<(), DecodingError> {
/// let data = b"MATRIX\ /// let data = b"MATRIX\
/// \x02\x02\x00\x06\ /// \x02\x02\x00\x06\
@ -627,8 +627,8 @@ impl SelfVerificationNoMasterKey {
/// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\ /// BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\
/// SHARED_SECRET"; /// SHARED_SECRET";
/// ///
/// let result = QrVerification::from_bytes(data)?; /// let result = QrVerificationData::from_bytes(data)?;
/// if let QrVerification::SelfVerificationNoMasterKey(decoded) = result { /// if let QrVerificationData::SelfVerificationNoMasterKey(decoded) = result {
/// let encoded = decoded.to_bytes().unwrap(); /// let encoded = decoded.to_bytes().unwrap();
/// assert_eq!(data.as_ref(), encoded.as_slice()); /// assert_eq!(data.as_ref(), encoded.as_slice());
/// } else { /// } else {
@ -665,7 +665,7 @@ impl SelfVerificationNoMasterKey {
} }
} }
impl From<SelfVerificationNoMasterKey> for QrVerification { impl From<SelfVerificationNoMasterKey> for QrVerificationData {
fn from(data: SelfVerificationNoMasterKey) -> Self { fn from(data: SelfVerificationNoMasterKey) -> Self {
Self::SelfVerificationNoMasterKey(data) Self::SelfVerificationNoMasterKey(data)
} }

View File

@ -8,7 +8,7 @@ license = "Apache-2.0"
name = "matrix-sdk" name = "matrix-sdk"
readme = "README.md" readme = "README.md"
repository = "https://github.com/matrix-org/matrix-rust-sdk" repository = "https://github.com/matrix-org/matrix-rust-sdk"
version = "0.2.0" version = "0.3.0"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["docs"] features = ["docs"]
@ -32,39 +32,39 @@ docs = ["encryption", "sled_cryptostore", "sled_state_store", "sso_login"]
[dependencies] [dependencies]
dashmap = "4.0.2" dashmap = "4.0.2"
futures = "0.3.12" futures = "0.3.15"
http = "0.2.3" http = "0.2.4"
serde_json = "1.0.61" serde_json = "1.0.64"
thiserror = "1.0.23" thiserror = "1.0.25"
tracing = "0.1.22" tracing = "0.1.26"
url = "2.2.0" url = "2.2.2"
zeroize = "1.2.0" zeroize = "1.3.0"
mime = "0.3.16" mime = "0.3.16"
rand = { version = "0.8.2", optional = true } rand = { version = "0.8.4", optional = true }
bytes = "1.0.1" bytes = "1.0.1"
matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" } matrix-sdk-common = { version = "0.3.0", path = "../matrix_sdk_common" }
[dependencies.matrix-sdk-base] [dependencies.matrix-sdk-base]
version = "0.2.0" version = "0.3.0"
path = "../matrix_sdk_base" path = "../matrix_sdk_base"
default_features = false default_features = false
[dependencies.reqwest] [dependencies.reqwest]
version = "0.11.0" version = "0.11.3"
default_features = false default_features = false
[dependencies.ruma] [dependencies.ruma]
version = "0.1.2" version = "0.2.0"
features = ["client-api-c", "compat", "unstable-pre-spec"] features = ["client-api-c", "compat", "unstable-pre-spec"]
[dependencies.tokio-stream] [dependencies.tokio-stream]
version = "0.1.4" version = "0.1.6"
features = ["net"] features = ["net"]
optional = true optional = true
[dependencies.warp] [dependencies.warp]
version = "0.3.0" version = "0.3.1"
default-features = false default-features = false
optional = true optional = true
@ -73,7 +73,7 @@ version = "0.3.0"
features = ["tokio"] features = ["tokio"]
[dependencies.tracing-futures] [dependencies.tracing-futures]
version = "0.2.4" version = "0.2.5"
default-features = false default-features = false
features = ["std", "std-future"] features = ["std", "std-future"]
@ -81,7 +81,7 @@ features = ["std", "std-future"]
futures-timer = "3.0.2" futures-timer = "3.0.2"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio] [target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio]
version = "1.1.0" version = "1.7.1"
default-features = false default-features = false
features = ["fs", "rt"] features = ["fs", "rt"]
@ -90,16 +90,15 @@ version = "3.0.2"
features = ["wasm-bindgen"] features = ["wasm-bindgen"]
[dev-dependencies] [dev-dependencies]
dirs = "3.0.1" dirs = "3.0.2"
matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" }
matches = "0.1.8" matches = "0.1.8"
tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] } matrix-sdk-test = { version = "0.3.0", path = "../matrix_sdk_test" }
serde_json = "1.0.61" tokio = { version = "1.7.1", default-features = false, features = ["rt-multi-thread", "macros"] }
tracing-subscriber = "0.2.15" serde_json = "1.0.64"
tracing-subscriber = "0.2.18"
tempfile = "3.2.0" tempfile = "3.2.0"
mockito = "0.29.0" mockito = "0.30.0"
lazy_static = "1.4.0" lazy_static = "1.4.0"
matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" }
[[example]] [[example]]
name = "emoji_verification" name = "emoji_verification"

View File

@ -1,9 +1,9 @@
use std::{env, process::exit}; use std::{env, process::exit};
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, async_trait,
events::{room::member::MemberEventContent, StrippedStateEvent},
room::Room, room::Room,
ruma::events::{room::member::MemberEventContent, StrippedStateEvent},
Client, ClientConfig, EventHandler, SyncSettings, Client, ClientConfig, EventHandler, SyncSettings,
}; };
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};

View File

@ -1,12 +1,12 @@
use std::{env, process::exit}; use std::{env, process::exit};
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, async_trait,
events::{ room::Room,
ruma::events::{
room::message::{MessageEventContent, MessageType, TextMessageEventContent}, room::message::{MessageEventContent, MessageType, TextMessageEventContent},
AnyMessageEventContent, SyncMessageEvent, AnyMessageEventContent, SyncMessageEvent,
}, },
room::Room,
Client, ClientConfig, EventHandler, SyncSettings, Client, ClientConfig, EventHandler, SyncSettings,
}; };
use url::Url; use url::Url;

View File

@ -6,7 +6,8 @@ use std::{
}; };
use matrix_sdk::{ use matrix_sdk::{
self, api::r0::uiaa::AuthData, identifiers::UserId, Client, LoopCtrl, SyncSettings, ruma::{api::client::r0::uiaa::AuthData, UserId},
Client, LoopCtrl, SyncSettings,
}; };
use serde_json::json; use serde_json::json;
use url::Url; use url::Url;

View File

@ -9,13 +9,18 @@ use std::{
use matrix_sdk::{ use matrix_sdk::{
self, self,
events::{room::message::MessageType, AnySyncMessageEvent, AnySyncRoomEvent, AnyToDeviceEvent}, ruma::{
identifiers::UserId, events::{
Client, LoopCtrl, Sas, SyncSettings, room::message::MessageType, AnySyncMessageEvent, AnySyncRoomEvent, AnyToDeviceEvent,
},
UserId,
},
verification::{SasVerification, Verification},
Client, LoopCtrl, SyncSettings,
}; };
use url::Url; use url::Url;
async fn wait_for_confirmation(client: Client, sas: Sas) { async fn wait_for_confirmation(client: Client, sas: SasVerification) {
println!("Does the emoji match: {:?}", sas.emoji()); println!("Does the emoji match: {:?}", sas.emoji());
let mut input = String::new(); let mut input = String::new();
@ -34,7 +39,7 @@ async fn wait_for_confirmation(client: Client, sas: Sas) {
} }
} }
fn print_result(sas: &Sas) { fn print_result(sas: &SasVerification) {
let device = sas.other_device(); let device = sas.other_device();
println!( println!(
@ -53,7 +58,7 @@ async fn print_devices(user_id: &UserId, client: &Client) {
" {:<10} {:<30} {:<}", " {:<10} {:<30} {:<}",
device.device_id(), device.device_id(),
device.display_name().as_deref().unwrap_or_default(), device.display_name().as_deref().unwrap_or_default(),
device.is_trusted() device.verified()
); );
} }
} }
@ -80,37 +85,35 @@ async fn login(
for event in response.to_device.events.iter().filter_map(|e| e.deserialize().ok()) { for event in response.to_device.events.iter().filter_map(|e| e.deserialize().ok()) {
match event { match event {
AnyToDeviceEvent::KeyVerificationStart(e) => { AnyToDeviceEvent::KeyVerificationStart(e) => {
let sas = client if let Some(Verification::SasV1(sas)) =
.get_verification(&e.content.transaction_id) client.get_verification(&e.sender, &e.content.transaction_id).await
.await {
.expect("Sas object wasn't created"); println!(
println!( "Starting verification with {} {}",
"Starting verification with {} {}", &sas.other_device().user_id(),
&sas.other_device().user_id(), &sas.other_device().device_id()
&sas.other_device().device_id() );
); print_devices(&e.sender, client).await;
print_devices(&e.sender, client).await; sas.accept().await.unwrap();
sas.accept().await.unwrap(); }
} }
AnyToDeviceEvent::KeyVerificationKey(e) => { AnyToDeviceEvent::KeyVerificationKey(e) => {
let sas = client if let Some(Verification::SasV1(sas)) =
.get_verification(&e.content.transaction_id) client.get_verification(&e.sender, &e.content.transaction_id).await
.await {
.expect("Sas object wasn't created"); tokio::spawn(wait_for_confirmation((*client).clone(), sas));
}
tokio::spawn(wait_for_confirmation((*client).clone(), sas));
} }
AnyToDeviceEvent::KeyVerificationMac(e) => { AnyToDeviceEvent::KeyVerificationMac(e) => {
let sas = client if let Some(Verification::SasV1(sas)) =
.get_verification(&e.content.transaction_id) client.get_verification(&e.sender, &e.content.transaction_id).await
.await {
.expect("Sas object wasn't created"); if sas.is_done() {
print_result(&sas);
if sas.is_done() { print_devices(&e.sender, client).await;
print_result(&sas); }
print_devices(&e.sender, client).await;
} }
} }
@ -129,7 +132,7 @@ async fn login(
if let MessageType::VerificationRequest(_) = &m.content.msgtype if let MessageType::VerificationRequest(_) = &m.content.msgtype
{ {
let request = client let request = client
.get_verification_request(&m.event_id) .get_verification_request(&m.sender, &m.event_id)
.await .await
.expect("Request object wasn't created"); .expect("Request object wasn't created");
@ -140,22 +143,28 @@ async fn login(
} }
} }
AnySyncMessageEvent::KeyVerificationKey(e) => { AnySyncMessageEvent::KeyVerificationKey(e) => {
let sas = client if let Some(Verification::SasV1(sas)) = client
.get_verification(e.content.relation.event_id.as_str()) .get_verification(
&e.sender,
e.content.relates_to.event_id.as_str(),
)
.await .await
.expect("Sas object wasn't created"); {
tokio::spawn(wait_for_confirmation((*client).clone(), sas));
tokio::spawn(wait_for_confirmation((*client).clone(), sas)); }
} }
AnySyncMessageEvent::KeyVerificationMac(e) => { AnySyncMessageEvent::KeyVerificationMac(e) => {
let sas = client if let Some(Verification::SasV1(sas)) = client
.get_verification(e.content.relation.event_id.as_str()) .get_verification(
&e.sender,
e.content.relates_to.event_id.as_str(),
)
.await .await
.expect("Sas object wasn't created"); {
if sas.is_done() {
if sas.is_done() { print_result(&sas);
print_result(&sas); print_devices(&e.sender, client).await;
print_devices(&e.sender, client).await; }
} }
} }
_ => (), _ => (),

View File

@ -1,9 +1,7 @@
use std::{convert::TryFrom, env, process::exit}; use std::{convert::TryFrom, env, process::exit};
use matrix_sdk::{ use matrix_sdk::{
self, ruma::{api::client::r0::profile, MxcUri, UserId},
api::r0::profile,
identifiers::{MxcUri, UserId},
Client, Result as MatrixResult, Client, Result as MatrixResult,
}; };
use url::Url; use url::Url;

View File

@ -9,11 +9,11 @@ use std::{
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, self, async_trait,
events::{ room::Room,
ruma::events::{
room::message::{MessageEventContent, MessageType, TextMessageEventContent}, room::message::{MessageEventContent, MessageType, TextMessageEventContent},
SyncMessageEvent, SyncMessageEvent,
}, },
room::Room,
Client, EventHandler, SyncSettings, Client, EventHandler, SyncSettings,
}; };
use tokio::sync::Mutex; use tokio::sync::Mutex;

View File

@ -2,11 +2,11 @@ use std::{env, process::exit};
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, self, async_trait,
events::{ room::Room,
ruma::events::{
room::message::{MessageEventContent, MessageType, TextMessageEventContent}, room::message::{MessageEventContent, MessageType, TextMessageEventContent},
SyncMessageEvent, SyncMessageEvent,
}, },
room::Room,
Client, EventHandler, SyncSettings, Client, EventHandler, SyncSettings,
}; };
use url::Url; use url::Url;

View File

@ -10,11 +10,11 @@ edition = "2018"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
url = "2.2.1" url = "2.2.2"
wasm-bindgen = { version = "0.2.72", features = ["serde-serialize"] } wasm-bindgen = { version = "0.2.74", features = ["serde-serialize"] }
wasm-bindgen-futures = "0.4.22" wasm-bindgen-futures = "0.4.24"
console_error_panic_hook = "0.1.6" console_error_panic_hook = "0.1.6"
web-sys = { version = "0.3.49", features = ["console"] } web-sys = { version = "0.3.51", features = ["console"] }
[dependencies.matrix-sdk] [dependencies.matrix-sdk]
path = "../.." path = "../.."

View File

@ -1,10 +1,12 @@
use matrix_sdk::{ use matrix_sdk::{
deserialized_responses::SyncResponse, deserialized_responses::SyncResponse,
events::{ ruma::{
room::message::{MessageEventContent, MessageType, TextMessageEventContent}, events::{
AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent, room::message::{MessageEventContent, MessageType, TextMessageEventContent},
AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent,
},
RoomId,
}, },
identifiers::RoomId,
Client, LoopCtrl, SyncSettings, Client, LoopCtrl, SyncSettings,
}; };
use url::Url; use url::Url;
@ -58,7 +60,9 @@ impl WasmBot {
for (room_id, room) in response.rooms.join { for (room_id, room) in response.rooms.join {
for event in room.timeline.events { for event in room.timeline.events {
if let Ok(AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(ev))) = event.event.deserialize() { if let Ok(AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(ev))) =
event.event.deserialize()
{
self.on_room_message(&room_id, &ev).await self.on_room_message(&room_id, &ev).await
} }
} }
@ -79,19 +83,14 @@ pub async fn run() -> Result<JsValue, JsValue> {
let homeserver_url = Url::parse(&homeserver_url).unwrap(); let homeserver_url = Url::parse(&homeserver_url).unwrap();
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
client client.login(username, password, None, Some("rust-sdk-wasm")).await.unwrap();
.login(username, password, None, Some("rust-sdk-wasm"))
.await
.unwrap();
let bot = WasmBot(client.clone()); let bot = WasmBot(client.clone());
client.sync_once(SyncSettings::default()).await.unwrap(); client.sync_once(SyncSettings::default()).await.unwrap();
let settings = SyncSettings::default().token(client.sync_token().await.unwrap()); let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
client client.sync_with_callback(settings, |response| bot.on_sync_response(response)).await;
.sync_with_callback(settings, |response| bot.on_sync_response(response))
.await;
Ok(JsValue::NULL) Ok(JsValue::NULL)
} }

View File

@ -13,11 +13,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(all(feature = "encryption", not(target_arch = "wasm32")))]
use std::path::PathBuf;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
io::{Cursor, Write}, io::{Cursor, Write},
path::PathBuf,
}; };
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use std::{ use std::{
@ -39,10 +40,12 @@ use futures_timer::Delay as sleep;
use http::HeaderValue; use http::HeaderValue;
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use http::Response; use http::Response;
#[cfg(all(feature = "encryption", not(target_arch = "wasm32")))]
use matrix_sdk_base::crypto::{decrypt_key_export, encrypt_key_export, olm::InboundGroupSession};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::{ use matrix_sdk_base::crypto::{
decrypt_key_export, encrypt_key_export, olm::InboundGroupSession, store::CryptoStoreError, store::CryptoStoreError, AttachmentDecryptor, OutgoingRequests, RoomMessageRequest,
AttachmentDecryptor, OutgoingRequests, RoomMessageRequest, ToDeviceRequest, ToDeviceRequest,
}; };
use matrix_sdk_base::{ use matrix_sdk_base::{
deserialized_responses::SyncResponse, deserialized_responses::SyncResponse,
@ -53,7 +56,7 @@ use mime::{self, Mime};
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use reqwest::header::InvalidHeaderValue; use reqwest::header::InvalidHeaderValue;
use ruma::{api::SendAccessToken, events::AnyMessageEventContent, identifiers::MxcUri}; use ruma::{api::SendAccessToken, events::AnyMessageEventContent, MxcUri};
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use tokio::{net::TcpListener, sync::oneshot}; use tokio::{net::TcpListener, sync::oneshot};
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
@ -64,7 +67,7 @@ use tracing::{error, info, instrument};
use url::Url; use url::Url;
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use warp::Filter; use warp::Filter;
#[cfg(feature = "encryption")] #[cfg(all(feature = "encryption", not(target_arch = "wasm32")))]
use zeroize::Zeroizing; use zeroize::Zeroizing;
/// Enum controlling if a loop running callbacks should continue or abort. /// Enum controlling if a loop running callbacks should continue or abort.
@ -126,8 +129,8 @@ use ruma::{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use crate::{ use crate::{
device::{Device, UserDevices}, device::{Device, UserDevices},
sas::Sas, error::RoomKeyImportError,
verification_request::VerificationRequest, verification::{QrVerification, SasVerification, Verification, VerificationRequest},
}; };
use crate::{ use crate::{
error::HttpError, error::HttpError,
@ -501,7 +504,7 @@ impl RequestConfig {
/// All outgoing http requests will have a GET query key-value appended with /// All outgoing http requests will have a GET query key-value appended with
/// `user_id` being the key and the `user_id` from the `Session` being /// `user_id` being the key and the `user_id` from the `Session` being
/// the value. Will error if there's no `Session`. This is called /// the value. Will error if there's no `Session`. This is called
/// [identity assertion] in the Matrix Appservice Spec /// [identity assertion] in the Matrix Application Service Spec
/// ///
/// [identity assertion]: https://spec.matrix.org/unstable/application-service-api/#identity-assertion /// [identity assertion]: https://spec.matrix.org/unstable/application-service-api/#identity-assertion
#[cfg(feature = "appservice")] #[cfg(feature = "appservice")]
@ -572,7 +575,7 @@ impl Client {
/// # Example /// # Example
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, identifiers::UserId}; /// # use matrix_sdk::{Client, ruma::UserId};
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// let alice = UserId::try_from("@alice:example.org").unwrap(); /// let alice = UserId::try_from("@alice:example.org").unwrap();
/// # block_on(async { /// # block_on(async {
@ -781,20 +784,20 @@ impl Client {
/// Gets the avatar of the owner of the client, if set. /// Gets the avatar of the owner of the client, if set.
/// ///
/// Returns the avatar. No guarantee on the size of the image is given. /// Returns the avatar.
/// If no size is given the full-sized avatar will be returned. /// If a thumbnail is requested no guarantee on the size of the image is
/// given.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `width` - The desired width of the avatar. /// * `format` - The desired format of the avatar.
///
/// * `height` - The desired height of the avatar.
/// ///
/// # Example /// # Example
/// ```no_run /// ```no_run
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use matrix_sdk::Client; /// # use matrix_sdk::Client;
/// # use matrix_sdk::identifiers::room_id; /// # use matrix_sdk::ruma::room_id;
/// # use matrix_sdk::media::MediaFormat;
/// # use url::Url; /// # use url::Url;
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
/// # block_on(async { /// # block_on(async {
@ -802,24 +805,15 @@ impl Client {
/// let client = Client::new(homeserver).unwrap(); /// let client = Client::new(homeserver).unwrap();
/// client.login(user, "password", None, None).await.unwrap(); /// client.login(user, "password", None, None).await.unwrap();
/// ///
/// if let Some(avatar) = client.avatar(Some(96), Some(96)).await.unwrap() { /// if let Some(avatar) = client.avatar(MediaFormat::File).await.unwrap() {
/// std::fs::write("avatar.png", avatar); /// std::fs::write("avatar.png", avatar);
/// } /// }
/// # }) /// # })
/// ``` /// ```
pub async fn avatar(&self, width: Option<u32>, height: Option<u32>) -> Result<Option<Vec<u8>>> { pub async fn avatar(&self, format: MediaFormat) -> Result<Option<Vec<u8>>> {
// TODO: try to offer the avatar from cache, requires avatar cache
if let Some(url) = self.avatar_url().await? { if let Some(url) = self.avatar_url().await? {
if let (Some(width), Some(height)) = (width, height) { let request = MediaRequest { media_type: MediaType::Uri(url), format };
let request = Ok(Some(self.get_media_content(&request, true).await?))
get_content_thumbnail::Request::from_url(&url, width.into(), height.into())?;
let response = self.send(request, None).await?;
Ok(Some(response.file))
} else {
let request = get_content::Request::from_url(&url)?;
let response = self.send(request, None).await?;
Ok(Some(response.file))
}
} else { } else {
Ok(None) Ok(None)
} }
@ -1011,8 +1005,7 @@ impl Client {
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// # use matrix_sdk::Client; /// # use matrix_sdk::Client;
/// # use matrix_sdk::identifiers::DeviceId; /// # use matrix_sdk::ruma::{assign, DeviceId};
/// # use matrix_sdk::assign;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use url::Url; /// # use url::Url;
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
@ -1270,8 +1263,7 @@ impl Client {
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// # use matrix_sdk::Client; /// # use matrix_sdk::Client;
/// # use matrix_sdk::identifiers::DeviceId; /// # use matrix_sdk::ruma::{assign, DeviceId};
/// # use matrix_sdk::assign;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use url::Url; /// # use url::Url;
/// # let homeserver = Url::parse("https://example.com").unwrap(); /// # let homeserver = Url::parse("https://example.com").unwrap();
@ -1350,10 +1342,13 @@ impl Client {
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// # use matrix_sdk::Client; /// # use matrix_sdk::Client;
/// # use matrix_sdk::api::r0::account::register::{Request as RegistrationRequest, RegistrationKind}; /// # use matrix_sdk::ruma::{
/// # use matrix_sdk::api::r0::uiaa::AuthData; /// # api::client::r0::{
/// # use matrix_sdk::identifiers::DeviceId; /// # account::register::{Request as RegistrationRequest, RegistrationKind},
/// # use matrix_sdk::assign; /// # uiaa::AuthData,
/// # },
/// # assign, DeviceId,
/// # };
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use url::Url; /// # use url::Url;
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
@ -1403,7 +1398,7 @@ impl Client {
/// ```no_run /// ```no_run
/// # use matrix_sdk::{ /// # use matrix_sdk::{
/// # Client, SyncSettings, /// # Client, SyncSettings,
/// # api::r0::{ /// # ruma::api::client::r0::{
/// # filter::{ /// # filter::{
/// # FilterDefinition, LazyLoadOptions, RoomEventFilter, RoomFilter, /// # FilterDefinition, LazyLoadOptions, RoomEventFilter, RoomFilter,
/// # }, /// # },
@ -1547,7 +1542,10 @@ impl Client {
/// # Examples /// # Examples
/// ```no_run /// ```no_run
/// use matrix_sdk::Client; /// use matrix_sdk::Client;
/// # use matrix_sdk::api::r0::room::{create_room::Request as CreateRoomRequest, Visibility}; /// # use matrix_sdk::ruma::api::client::r0::room::{
/// # create_room::Request as CreateRoomRequest,
/// # Visibility,
/// # };
/// # use url::Url; /// # use url::Url;
/// ///
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
@ -1580,9 +1578,11 @@ impl Client {
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// # use matrix_sdk::Client; /// # use matrix_sdk::Client;
/// # use matrix_sdk::directory::{Filter, RoomNetwork}; /// # use matrix_sdk::ruma::{
/// # use matrix_sdk::api::r0::directory::get_public_rooms_filtered::Request as PublicRoomsFilterRequest; /// # api::client::r0::directory::get_public_rooms_filtered::Request as PublicRoomsFilterRequest,
/// # use matrix_sdk::assign; /// # directory::{Filter, RoomNetwork},
/// # assign,
/// # };
/// # use url::Url; /// # use url::Url;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
@ -1633,7 +1633,7 @@ impl Client {
/// ///
/// ```no_run /// ```no_run
/// # use std::{path::PathBuf, fs::File, io::Read}; /// # use std::{path::PathBuf, fs::File, io::Read};
/// # use matrix_sdk::{Client, identifiers::room_id}; /// # use matrix_sdk::{Client, ruma::room_id};
/// # use url::Url; /// # use url::Url;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use mime; /// # use mime;
@ -1700,9 +1700,9 @@ impl Client {
/// # use matrix_sdk::{Client, SyncSettings}; /// # use matrix_sdk::{Client, SyncSettings};
/// # use url::Url; /// # use url::Url;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use matrix_sdk::identifiers::room_id; /// # use matrix_sdk::ruma::room_id;
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// use matrix_sdk::events::{ /// use matrix_sdk::ruma::events::{
/// AnyMessageEventContent, /// AnyMessageEventContent,
/// room::message::{MessageEventContent, TextMessageEventContent}, /// room::message::{MessageEventContent, TextMessageEventContent},
/// }; /// };
@ -1762,8 +1762,7 @@ impl Client {
/// # block_on(async { /// # block_on(async {
/// # let homeserver = Url::parse("http://localhost:8080").unwrap(); /// # let homeserver = Url::parse("http://localhost:8080").unwrap();
/// # let mut client = Client::new(homeserver).unwrap(); /// # let mut client = Client::new(homeserver).unwrap();
/// use matrix_sdk::api::r0::profile; /// use matrix_sdk::ruma::{api::client::r0::profile, user_id};
/// use matrix_sdk::identifiers::user_id;
/// ///
/// // First construct the request you want to make /// // First construct the request you want to make
/// // See https://docs.rs/ruma-client-api/latest/ruma_client_api/index.html /// // See https://docs.rs/ruma-client-api/latest/ruma_client_api/index.html
@ -1796,8 +1795,8 @@ impl Client {
request: &ToDeviceRequest, request: &ToDeviceRequest,
) -> Result<ToDeviceResponse> { ) -> Result<ToDeviceResponse> {
let txn_id_string = request.txn_id_string(); let txn_id_string = request.txn_id_string();
let request = RumaToDeviceRequest::new( let request = RumaToDeviceRequest::new_raw(
request.event_type.clone(), request.event_type.as_str(),
&txn_id_string, &txn_id_string,
request.messages.clone(), request.messages.clone(),
); );
@ -1849,8 +1848,11 @@ impl Client {
/// ///
/// ```no_run /// ```no_run
/// # use matrix_sdk::{ /// # use matrix_sdk::{
/// # api::r0::uiaa::{UiaaResponse, AuthData}, /// # ruma::api::{
/// # Client, SyncSettings, Error, FromHttpResponseError, ServerError, /// # client::r0::uiaa::{UiaaResponse, AuthData},
/// # error::{FromHttpResponseError, ServerError},
/// # },
/// # Client, Error, SyncSettings,
/// # }; /// # };
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use serde_json::json; /// # use serde_json::json;
@ -1969,7 +1971,7 @@ impl Client {
/// UI thread. /// UI thread.
/// ///
/// ```no_run /// ```no_run
/// # use matrix_sdk::events::{ /// # use matrix_sdk::ruma::events::{
/// # room::message::{MessageEvent, MessageEventContent, TextMessageEventContent}, /// # room::message::{MessageEvent, MessageEventContent, TextMessageEventContent},
/// # }; /// # };
/// # use std::sync::{Arc, RwLock}; /// # use std::sync::{Arc, RwLock};
@ -2191,26 +2193,33 @@ impl Client {
Ok(response) Ok(response)
} }
/// Get a `Sas` verification object with the given flow id. /// Get a verification object with the given flow id.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> { pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
self.base_client let olm = self.base_client.olm_machine().await?;
.get_verification(flow_id) olm.get_verification(user_id, flow_id).map(|v| match v {
.await matrix_sdk_base::crypto::Verification::SasV1(s) => {
.map(|sas| Sas { inner: sas, client: self.clone() }) SasVerification { inner: s, client: self.clone() }.into()
}
matrix_sdk_base::crypto::Verification::QrV1(qr) => {
QrVerification { inner: qr, client: self.clone() }.into()
}
})
} }
/// Get a `VerificationRequest` object with the given flow id. /// Get a `VerificationRequest` object for the given user with the given
/// flow id.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification_request( pub async fn get_verification_request(
&self, &self,
user_id: &UserId,
flow_id: impl AsRef<str>, flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> { ) -> Option<VerificationRequest> {
let olm = self.base_client.olm_machine().await?; let olm = self.base_client.olm_machine().await?;
olm.get_verification_request(flow_id) olm.get_verification_request(user_id, flow_id)
.map(|r| VerificationRequest { inner: r, client: self.clone() }) .map(|r| VerificationRequest { inner: r, client: self.clone() })
} }
@ -2231,7 +2240,7 @@ impl Client {
/// ///
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, identifiers::UserId}; /// # use matrix_sdk::{Client, ruma::UserId};
/// # use url::Url; /// # use url::Url;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap(); /// # let alice = UserId::try_from("@alice:example.org").unwrap();
@ -2243,7 +2252,7 @@ impl Client {
/// .unwrap() /// .unwrap()
/// .unwrap(); /// .unwrap();
/// ///
/// println!("{:?}", device.is_trusted()); /// println!("{:?}", device.verified());
/// ///
/// let verification = device.start_verification().await.unwrap(); /// let verification = device.start_verification().await.unwrap();
/// # }); /// # });
@ -2273,8 +2282,8 @@ impl Client {
/// # Examples /// # Examples
/// ```no_run /// ```no_run
/// # use std::{convert::TryFrom, collections::BTreeMap}; /// # use std::{convert::TryFrom, collections::BTreeMap};
/// # use matrix_sdk::{Client, identifiers::UserId}; /// # use matrix_sdk::{Client, ruma::UserId};
/// # use matrix_sdk::api::r0::uiaa::AuthData; /// # use matrix_sdk::ruma::api::client::r0::uiaa::AuthData;
/// # use url::Url; /// # use url::Url;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use serde_json::json; /// # use serde_json::json;
@ -2344,7 +2353,7 @@ impl Client {
/// ///
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, identifiers::UserId}; /// # use matrix_sdk::{Client, ruma::UserId};
/// # use url::Url; /// # use url::Url;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap(); /// # let alice = UserId::try_from("@alice:example.org").unwrap();
@ -2398,7 +2407,7 @@ impl Client {
/// # use std::{path::PathBuf, time::Duration}; /// # use std::{path::PathBuf, time::Duration};
/// # use matrix_sdk::{ /// # use matrix_sdk::{
/// # Client, SyncSettings, /// # Client, SyncSettings,
/// # identifiers::room_id, /// # ruma::room_id,
/// # }; /// # };
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use url::Url; /// # use url::Url;
@ -2422,8 +2431,7 @@ impl Client {
/// .expect("Can't export keys."); /// .expect("Can't export keys.");
/// # }); /// # });
/// ``` /// ```
#[cfg(feature = "encryption")] #[cfg(all(feature = "encryption", not(target_arch = "wasm32")))]
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(feature = "docs", doc(cfg(all(encryption, not(target_arch = "wasm32")))))] #[cfg_attr(feature = "docs", doc(cfg(all(encryption, not(target_arch = "wasm32")))))]
pub async fn export_keys( pub async fn export_keys(
&self, &self,
@ -2468,7 +2476,7 @@ impl Client {
/// # use std::{path::PathBuf, time::Duration}; /// # use std::{path::PathBuf, time::Duration};
/// # use matrix_sdk::{ /// # use matrix_sdk::{
/// # Client, SyncSettings, /// # Client, SyncSettings,
/// # identifiers::room_id, /// # ruma::room_id,
/// # }; /// # };
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use url::Url; /// # use url::Url;
@ -2482,11 +2490,14 @@ impl Client {
/// .expect("Can't import keys"); /// .expect("Can't import keys");
/// # }); /// # });
/// ``` /// ```
#[cfg(feature = "encryption")] #[cfg(all(feature = "encryption", not(target_arch = "wasm32")))]
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(feature = "docs", doc(cfg(all(encryption, not(target_arch = "wasm32")))))] #[cfg_attr(feature = "docs", doc(cfg(all(encryption, not(target_arch = "wasm32")))))]
pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result<(usize, usize)> { pub async fn import_keys(
let olm = self.base_client.olm_machine().await.ok_or(Error::AuthenticationRequired)?; &self,
path: PathBuf,
passphrase: &str,
) -> StdResult<(usize, usize), RoomKeyImportError> {
let olm = self.base_client.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?;
let passphrase = Zeroizing::new(passphrase.to_owned()); let passphrase = Zeroizing::new(passphrase.to_owned());
let decrypt = move || { let decrypt = move || {
@ -2495,8 +2506,7 @@ impl Client {
}; };
let task = tokio::task::spawn_blocking(decrypt); let task = tokio::task::spawn_blocking(decrypt);
// TODO remove this unwrap. let import = task.await.expect("Task join error")?;
let import = task.await.expect("Task join error").unwrap();
Ok(olm.import_keys(import, |_, _| {}).await?) Ok(olm.import_keys(import, |_, _| {}).await?)
} }
@ -2704,6 +2714,23 @@ impl Client {
let request = whoami::Request::new(); let request = whoami::Request::new();
self.send(request, None).await self.send(request, None).await
} }
#[cfg(feature = "encryption")]
pub(crate) async fn send_verification_request(
&self,
request: matrix_sdk_base::crypto::OutgoingVerificationRequest,
) -> Result<()> {
match request {
matrix_sdk_base::crypto::OutgoingVerificationRequest::ToDevice(t) => {
self.send_to_device(&t).await?;
}
matrix_sdk_base::crypto::OutgoingVerificationRequest::InRoom(r) => {
self.room_send_helper(&r).await?;
}
}
Ok(())
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -18,9 +18,13 @@ use matrix_sdk_base::crypto::{
store::CryptoStoreError, Device as BaseDevice, LocalTrust, ReadOnlyDevice, store::CryptoStoreError, Device as BaseDevice, LocalTrust, ReadOnlyDevice,
UserDevices as BaseUserDevices, UserDevices as BaseUserDevices,
}; };
use ruma::{DeviceId, DeviceIdBox}; use ruma::{events::key::verification::VerificationMethod, DeviceId, DeviceIdBox};
use crate::{error::Result, Client, Sas}; use crate::{
error::Result,
verification::{SasVerification, VerificationRequest},
Client,
};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
/// A device represents a E2EE capable client of an user. /// A device represents a E2EE capable client of an user.
@ -43,11 +47,14 @@ impl Device {
/// Returns a `Sas` object that represents the interactive verification /// Returns a `Sas` object that represents the interactive verification
/// flow. /// flow.
/// ///
/// # Example /// This method has been deprecated in the spec and the
/// [`request_verification()`] method should be used instead.
///
/// # Examples
/// ///
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, identifiers::UserId}; /// # use matrix_sdk::{Client, ruma::UserId};
/// # use url::Url; /// # use url::Url;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap(); /// # let alice = UserId::try_from("@alice:example.org").unwrap();
@ -62,16 +69,106 @@ impl Device {
/// let verification = device.start_verification().await.unwrap(); /// let verification = device.start_verification().await.unwrap();
/// # }); /// # });
/// ``` /// ```
pub async fn start_verification(&self) -> Result<Sas> { ///
/// [`request_verification()`]: #method.request_verification
pub async fn start_verification(&self) -> Result<SasVerification> {
let (sas, request) = self.inner.start_verification().await?; let (sas, request) = self.inner.start_verification().await?;
self.client.send_to_device(&request).await?; self.client.send_to_device(&request).await?;
Ok(Sas { inner: sas, client: self.client.clone() }) Ok(SasVerification { inner: sas, client: self.client.clone() })
} }
/// Is the device trusted. /// Request an interacitve verification with this `Device`
pub fn is_trusted(&self) -> bool { ///
self.inner.trust_state() /// Returns a `VerificationRequest` object and a to-device request that
/// needs to be sent out.
///
/// The default methods that are supported are `m.sas.v1` and
/// `m.qr_code.show.v1`, if this isn't desireable the
/// [`request_verification_with_methods()`] method can be used to override
/// this.
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, ruma::UserId};
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let device = client.get_device(&alice, "DEVICEID".into())
/// .await
/// .unwrap()
/// .unwrap();
///
/// let verification = device.request_verification().await.unwrap();
/// # });
/// ```
///
/// [`request_verification_with_methods()`]:
/// #method.request_verification_with_methods
pub async fn request_verification(&self) -> Result<VerificationRequest> {
let (verification, request) = self.inner.request_verification().await;
self.client.send_verification_request(request).await?;
Ok(VerificationRequest { inner: verification, client: self.client.clone() })
}
/// Request an interacitve verification with this `Device`
///
/// Returns a `VerificationRequest` object and a to-device request that
/// needs to be sent out.
///
/// # Arguments
///
/// * `methods` - The verification methods that we want to support.
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{
/// # Client,
/// # ruma::{
/// # UserId,
/// # events::key::verification::VerificationMethod,
/// # }
/// # };
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let device = client.get_device(&alice, "DEVICEID".into())
/// .await
/// .unwrap()
/// .unwrap();
///
/// // We don't want to support showing a QR code, we only support SAS
/// // verification
/// let methods = vec![VerificationMethod::SasV1];
///
/// let verification = device.request_verification_with_methods(methods).await.unwrap();
/// # });
/// ```
pub async fn request_verification_with_methods(
&self,
methods: Vec<VerificationMethod>,
) -> Result<VerificationRequest> {
let (verification, request) = self.inner.request_verification_with_methods(methods).await;
self.client.send_verification_request(request).await?;
Ok(VerificationRequest { inner: verification, client: self.client.clone() })
}
/// Is the device considered to be verified, either by locally trusting it
/// or using cross signing.
pub fn verified(&self) -> bool {
self.inner.verified()
} }
/// Set the local trust state of the device to the given state. /// Set the local trust state of the device to the given state.

View File

@ -18,8 +18,10 @@ use std::io::Error as IoError;
use http::StatusCode; use http::StatusCode;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::{store::CryptoStoreError, DecryptorError}; use matrix_sdk_base::crypto::{
use matrix_sdk_base::{Error as MatrixError, StoreError}; CryptoStoreError, DecryptorError, KeyExportError, MegolmError, OlmError,
};
use matrix_sdk_base::{Error as SdkBaseError, StoreError};
use reqwest::Error as ReqwestError; use reqwest::Error as ReqwestError;
use ruma::{ use ruma::{
api::{ api::{
@ -114,17 +116,27 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
Io(#[from] IoError), Io(#[from] IoError),
/// An error occurred in the Matrix client library.
#[error(transparent)]
MatrixError(#[from] MatrixError),
/// An error occurred in the crypto store. /// An error occurred in the crypto store.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[error(transparent)] #[error(transparent)]
CryptoStoreError(#[from] CryptoStoreError), CryptoStoreError(#[from] CryptoStoreError),
/// An error occurred during a E2EE operation.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[error(transparent)]
OlmError(#[from] OlmError),
/// An error occurred during a E2EE group operation.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[error(transparent)]
MegolmError(#[from] MegolmError),
/// An error occurred during decryption. /// An error occurred during decryption.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[error(transparent)] #[error(transparent)]
DecryptorError(#[from] DecryptorError), DecryptorError(#[from] DecryptorError),
@ -141,6 +153,35 @@ pub enum Error {
Url(#[from] UrlParseError), Url(#[from] UrlParseError),
} }
/// Error for the room key importing functionality.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[derive(Error, Debug)]
// This is allowed because key importing isn't enabled under wasm.
#[allow(dead_code)]
pub enum RoomKeyImportError {
/// An error de/serializing type for the `StateStore`
#[error(transparent)]
SerdeJson(#[from] JsonError),
/// The cryptostore isn't yet open, logging in is required to open the
/// cryptostore.
#[error("The cryptostore hasn't been yet opened, can't import yet.")]
StoreClosed,
/// An IO error happened.
#[error(transparent)]
Io(#[from] IoError),
/// An error occurred in the crypto store.
#[error(transparent)]
CryptoStore(#[from] CryptoStoreError),
/// An error occurred while importing the key export.
#[error(transparent)]
Export(#[from] KeyExportError),
}
impl Error { impl Error {
/// Try to destructure the error into an universal interactive auth info. /// Try to destructure the error into an universal interactive auth info.
/// ///
@ -165,6 +206,23 @@ impl Error {
} }
} }
impl From<SdkBaseError> for Error {
fn from(e: SdkBaseError) -> Self {
match e {
SdkBaseError::AuthenticationRequired => Self::AuthenticationRequired,
SdkBaseError::StateStore(e) => Self::StateStore(e),
SdkBaseError::SerdeJson(e) => Self::SerdeJson(e),
SdkBaseError::IoError(e) => Self::Io(e),
#[cfg(feature = "encryption")]
SdkBaseError::CryptoStore(e) => Self::CryptoStoreError(e),
#[cfg(feature = "encryption")]
SdkBaseError::OlmError(e) => Self::OlmError(e),
#[cfg(feature = "encryption")]
SdkBaseError::MegolmError(e) => Self::MegolmError(e),
}
}
}
impl From<ReqwestError> for Error { impl From<ReqwestError> for Error {
fn from(e: ReqwestError) -> Self { fn from(e: ReqwestError) -> Self {
Error::Http(HttpError::Reqwest(e)) Error::Http(HttpError::Reqwest(e))

View File

@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
use std::ops::Deref; use std::ops::Deref;
use matrix_sdk_base::{hoist_and_deserialize_state_event, hoist_room_event_prev_content};
use matrix_sdk_common::async_trait; use matrix_sdk_common::async_trait;
use ruma::{ use ruma::{
api::client::r0::push::get_notifications::Notification, api::client::r0::push::get_notifications::Notification,
@ -27,6 +28,7 @@ use ruma::{
ignored_user_list::IgnoredUserListEventContent, ignored_user_list::IgnoredUserListEventContent,
presence::PresenceEvent, presence::PresenceEvent,
push_rules::PushRulesEventContent, push_rules::PushRulesEventContent,
reaction::ReactionEventContent,
receipt::ReceiptEventContent, receipt::ReceiptEventContent,
room::{ room::{
aliases::AliasesEventContent, aliases::AliasesEventContent,
@ -46,6 +48,7 @@ use ruma::{
GlobalAccountDataEvent, RoomAccountDataEvent, StrippedStateEvent, SyncEphemeralRoomEvent, GlobalAccountDataEvent, RoomAccountDataEvent, StrippedStateEvent, SyncEphemeralRoomEvent,
SyncMessageEvent, SyncStateEvent, SyncMessageEvent, SyncStateEvent,
}, },
serde::Raw,
RoomId, RoomId,
}; };
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
@ -88,14 +91,24 @@ impl Handler {
self.handle_room_account_data_event(room.clone(), &event).await; self.handle_room_account_data_event(room.clone(), &event).await;
} }
for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) { for (raw_event, event) in room_info.state.events.iter().filter_map(|e| {
self.handle_state_event(room.clone(), &event).await; if let Ok(d) = hoist_and_deserialize_state_event(e) {
Some((e, d))
} else {
None
}
}) {
self.handle_state_event(room.clone(), &event, raw_event).await;
} }
for event in for (raw_event, event) in room_info.timeline.events.iter().filter_map(|e| {
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok()) if let Ok(d) = hoist_room_event_prev_content(&e.event) {
{ Some((&e.event, d))
self.handle_timeline_event(room.clone(), &event).await; } else {
None
}
}) {
self.handle_timeline_event(room.clone(), &event, raw_event).await;
} }
} }
} }
@ -108,14 +121,24 @@ impl Handler {
self.handle_room_account_data_event(room.clone(), &event).await; self.handle_room_account_data_event(room.clone(), &event).await;
} }
for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) { for (raw_event, event) in room_info.state.events.iter().filter_map(|e| {
self.handle_state_event(room.clone(), &event).await; if let Ok(d) = hoist_and_deserialize_state_event(e) {
Some((e, d))
} else {
None
}
}) {
self.handle_state_event(room.clone(), &event, raw_event).await;
} }
for event in for (raw_event, event) in room_info.timeline.events.iter().filter_map(|e| {
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok()) if let Ok(d) = hoist_room_event_prev_content(&e.event) {
{ Some((&e.event, d))
self.handle_timeline_event(room.clone(), &event).await; } else {
None
}
}) {
self.handle_timeline_event(room.clone(), &event, raw_event).await;
} }
} }
} }
@ -143,7 +166,12 @@ impl Handler {
} }
} }
async fn handle_timeline_event(&self, room: Room, event: &AnySyncRoomEvent) { async fn handle_timeline_event(
&self,
room: Room,
event: &AnySyncRoomEvent,
raw_event: &Raw<AnySyncRoomEvent>,
) {
match event { match event {
AnySyncRoomEvent::State(event) => match event { AnySyncRoomEvent::State(event) => match event {
AnySyncStateEvent::RoomMember(e) => self.on_room_member(room, e).await, AnySyncStateEvent::RoomMember(e) => self.on_room_member(room, e).await,
@ -156,10 +184,25 @@ impl Handler {
AnySyncStateEvent::RoomPowerLevels(e) => self.on_room_power_levels(room, e).await, AnySyncStateEvent::RoomPowerLevels(e) => self.on_room_power_levels(room, e).await,
AnySyncStateEvent::RoomTombstone(e) => self.on_room_tombstone(room, e).await, AnySyncStateEvent::RoomTombstone(e) => self.on_room_tombstone(room, e).await,
AnySyncStateEvent::RoomJoinRules(e) => self.on_room_join_rules(room, e).await, AnySyncStateEvent::RoomJoinRules(e) => self.on_room_join_rules(room, e).await,
AnySyncStateEvent::Custom(e) => { AnySyncStateEvent::PolicyRuleRoom(_)
self.on_custom_event(room, &CustomEvent::State(e)).await | AnySyncStateEvent::PolicyRuleServer(_)
| AnySyncStateEvent::PolicyRuleUser(_)
| AnySyncStateEvent::RoomCreate(_)
| AnySyncStateEvent::RoomEncryption(_)
| AnySyncStateEvent::RoomGuestAccess(_)
| AnySyncStateEvent::RoomHistoryVisibility(_)
| AnySyncStateEvent::RoomPinnedEvents(_)
| AnySyncStateEvent::RoomServerAcl(_)
| AnySyncStateEvent::RoomThirdPartyInvite(_)
| AnySyncStateEvent::RoomTopic(_)
| AnySyncStateEvent::SpaceChild(_)
| AnySyncStateEvent::SpaceParent(_) => {}
_ => {
if let Ok(e) = raw_event.deserialize_as::<SyncStateEvent<CustomEventContent>>()
{
self.on_custom_event(room, &CustomEvent::State(&e)).await;
}
} }
_ => {}
}, },
AnySyncRoomEvent::Message(event) => match event { AnySyncRoomEvent::Message(event) => match event {
AnySyncMessageEvent::RoomMessage(e) => self.on_room_message(room, e).await, AnySyncMessageEvent::RoomMessage(e) => self.on_room_message(room, e).await,
@ -167,23 +210,41 @@ impl Handler {
self.on_room_message_feedback(room, e).await self.on_room_message_feedback(room, e).await
} }
AnySyncMessageEvent::RoomRedaction(e) => self.on_room_redaction(room, e).await, AnySyncMessageEvent::RoomRedaction(e) => self.on_room_redaction(room, e).await,
AnySyncMessageEvent::Custom(e) => { AnySyncMessageEvent::Reaction(e) => self.on_room_reaction(room, e).await,
self.on_custom_event(room, &CustomEvent::Message(e)).await
}
AnySyncMessageEvent::CallInvite(e) => self.on_room_call_invite(room, e).await, AnySyncMessageEvent::CallInvite(e) => self.on_room_call_invite(room, e).await,
AnySyncMessageEvent::CallAnswer(e) => self.on_room_call_answer(room, e).await, AnySyncMessageEvent::CallAnswer(e) => self.on_room_call_answer(room, e).await,
AnySyncMessageEvent::CallCandidates(e) => { AnySyncMessageEvent::CallCandidates(e) => {
self.on_room_call_candidates(room, e).await self.on_room_call_candidates(room, e).await
} }
AnySyncMessageEvent::CallHangup(e) => self.on_room_call_hangup(room, e).await, AnySyncMessageEvent::CallHangup(e) => self.on_room_call_hangup(room, e).await,
_ => {} AnySyncMessageEvent::KeyVerificationReady(_)
| AnySyncMessageEvent::KeyVerificationStart(_)
| AnySyncMessageEvent::KeyVerificationCancel(_)
| AnySyncMessageEvent::KeyVerificationAccept(_)
| AnySyncMessageEvent::KeyVerificationKey(_)
| AnySyncMessageEvent::KeyVerificationMac(_)
| AnySyncMessageEvent::KeyVerificationDone(_)
| AnySyncMessageEvent::RoomEncrypted(_)
| AnySyncMessageEvent::Sticker(_) => {}
_ => {
if let Ok(e) =
raw_event.deserialize_as::<SyncMessageEvent<CustomEventContent>>()
{
self.on_custom_event(room, &CustomEvent::Message(&e)).await;
}
}
}, },
AnySyncRoomEvent::RedactedState(_event) => {} AnySyncRoomEvent::RedactedState(_event) => {}
AnySyncRoomEvent::RedactedMessage(_event) => {} AnySyncRoomEvent::RedactedMessage(_event) => {}
} }
} }
async fn handle_state_event(&self, room: Room, event: &AnySyncStateEvent) { async fn handle_state_event(
&self,
room: Room,
event: &AnySyncStateEvent,
raw_event: &Raw<AnySyncStateEvent>,
) {
match event { match event {
AnySyncStateEvent::RoomMember(member) => self.on_state_member(room, member).await, AnySyncStateEvent::RoomMember(member) => self.on_state_member(room, member).await,
AnySyncStateEvent::RoomName(name) => self.on_state_name(room, name).await, AnySyncStateEvent::RoomName(name) => self.on_state_name(room, name).await,
@ -200,10 +261,24 @@ impl Handler {
// TODO make `on_state_tombstone` method // TODO make `on_state_tombstone` method
self.on_room_tombstone(room, tomb).await self.on_room_tombstone(room, tomb).await
} }
AnySyncStateEvent::Custom(custom) => { AnySyncStateEvent::PolicyRuleRoom(_)
self.on_custom_event(room, &CustomEvent::State(custom)).await | AnySyncStateEvent::PolicyRuleServer(_)
| AnySyncStateEvent::PolicyRuleUser(_)
| AnySyncStateEvent::RoomCreate(_)
| AnySyncStateEvent::RoomEncryption(_)
| AnySyncStateEvent::RoomGuestAccess(_)
| AnySyncStateEvent::RoomHistoryVisibility(_)
| AnySyncStateEvent::RoomPinnedEvents(_)
| AnySyncStateEvent::RoomServerAcl(_)
| AnySyncStateEvent::RoomThirdPartyInvite(_)
| AnySyncStateEvent::RoomTopic(_)
| AnySyncStateEvent::SpaceChild(_)
| AnySyncStateEvent::SpaceParent(_) => {}
_ => {
if let Ok(e) = raw_event.deserialize_as::<SyncStateEvent<CustomEventContent>>() {
self.on_custom_event(room, &CustomEvent::State(&e)).await;
}
} }
_ => {}
} }
} }
@ -301,7 +376,7 @@ pub enum CustomEvent<'c> {
/// # use matrix_sdk::{ /// # use matrix_sdk::{
/// # async_trait, /// # async_trait,
/// # EventHandler, /// # EventHandler,
/// # events::{ /// # ruma::events::{
/// # room::message::{MessageEventContent, MessageType, TextMessageEventContent}, /// # room::message::{MessageEventContent, MessageType, TextMessageEventContent},
/// # SyncMessageEvent /// # SyncMessageEvent
/// # }, /// # },
@ -358,6 +433,8 @@ pub trait EventHandler: Send + Sync {
async fn on_room_message(&self, _: Room, _: &SyncMessageEvent<MsgEventContent>) {} async fn on_room_message(&self, _: Room, _: &SyncMessageEvent<MsgEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomMessageFeedback` event. /// Fires when `Client` receives a `RoomEvent::RoomMessageFeedback` event.
async fn on_room_message_feedback(&self, _: Room, _: &SyncMessageEvent<FeedbackEventContent>) {} async fn on_room_message_feedback(&self, _: Room, _: &SyncMessageEvent<FeedbackEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::Reaction` event.
async fn on_room_reaction(&self, _: Room, _: &SyncMessageEvent<ReactionEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::CallInvite` event /// Fires when `Client` receives a `RoomEvent::CallInvite` event
async fn on_room_call_invite(&self, _: Room, _: &SyncMessageEvent<InviteEventContent>) {} async fn on_room_call_invite(&self, _: Room, _: &SyncMessageEvent<InviteEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::CallAnswer` event /// Fires when `Client` receives a `RoomEvent::CallAnswer` event

View File

@ -12,15 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(all(not(target_arch = "wasm32")))]
use std::sync::atomic::{AtomicU64, Ordering};
use std::{convert::TryFrom, fmt::Debug, sync::Arc}; use std::{convert::TryFrom, fmt::Debug, sync::Arc};
#[cfg(all(not(target_arch = "wasm32")))] use bytes::{Bytes, BytesMut};
use backoff::{future::retry, Error as RetryError, ExponentialBackoff}; use http::Response as HttpResponse;
#[cfg(all(not(target_arch = "wasm32")))]
use http::StatusCode;
use http::{HeaderValue, Response as HttpResponse};
use matrix_sdk_common::{async_trait, locks::RwLock, AsyncTraitDeps}; use matrix_sdk_common::{async_trait, locks::RwLock, AsyncTraitDeps};
use reqwest::{Client, Response}; use reqwest::{Client, Response};
use ruma::api::{ use ruma::api::{
@ -30,7 +25,7 @@ use ruma::api::{
use tracing::trace; use tracing::trace;
use url::Url; use url::Url;
use crate::{error::HttpError, Bytes, BytesMut, ClientConfig, RequestConfig, Session}; use crate::{error::HttpError, ClientConfig, RequestConfig, Session};
/// Abstraction around the http layer. The allows implementors to use different /// Abstraction around the http layer. The allows implementors to use different
/// http libraries. /// http libraries.
@ -54,7 +49,7 @@ pub trait HttpSend: AsyncTraitDeps {
/// ///
/// ``` /// ```
/// use std::convert::TryFrom; /// use std::convert::TryFrom;
/// use matrix_sdk::{HttpSend, async_trait, HttpError, RequestConfig, Bytes}; /// use matrix_sdk::{HttpSend, async_trait, HttpError, RequestConfig, bytes::Bytes};
/// ///
/// #[derive(Debug)] /// #[derive(Debug)]
/// struct Client(reqwest::Client); /// struct Client(reqwest::Client);
@ -235,6 +230,8 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result<Client, HttpEr
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
let http_client = { let http_client = {
use http::HeaderValue;
let http_client = if config.disable_ssl_verification { let http_client = if config.disable_ssl_verification {
http_client.danger_accept_invalid_certs(true) http_client.danger_accept_invalid_certs(true)
} else { } else {
@ -303,6 +300,11 @@ async fn send_request(
request: http::Request<Bytes>, request: http::Request<Bytes>,
config: RequestConfig, config: RequestConfig,
) -> Result<http::Response<Bytes>, HttpError> { ) -> Result<http::Response<Bytes>, HttpError> {
use std::sync::atomic::{AtomicU64, Ordering};
use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
use http::StatusCode;
let mut backoff = ExponentialBackoff::default(); let mut backoff = ExponentialBackoff::default();
let mut request = reqwest::Request::try_from(request)?; let mut request = reqwest::Request::try_from(request)?;
let retry_limit = config.retry_limit; let retry_limit = config.retry_limit;

View File

@ -75,33 +75,18 @@ compile_error!("only one of 'native-tls' or 'rustls-tls' features can be enabled
#[cfg(all(feature = "sso_login", target_arch = "wasm32"))] #[cfg(all(feature = "sso_login", target_arch = "wasm32"))]
compile_error!("'sso_login' cannot be enabled on 'wasm32' arch"); compile_error!("'sso_login' cannot be enabled on 'wasm32' arch");
pub use bytes::{Bytes, BytesMut}; pub use bytes;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust}; pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust};
pub use matrix_sdk_base::{ pub use matrix_sdk_base::{
media, Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType, media, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType, Session,
Session, StateChanges, StoreError, StateChanges, StoreError,
}; };
pub use matrix_sdk_common::*; pub use matrix_sdk_common::*;
pub use reqwest; pub use reqwest;
#[cfg(feature = "appservice")] #[doc(no_inline)]
pub use ruma::{ pub use ruma;
api::{appservice as api_appservice, IncomingRequest, OutgoingRequestAppserviceExt},
serde::{exports::serde::de::value::Error as SerdeError, urlencoded},
};
pub use ruma::{
api::{
client as api,
error::{
FromHttpRequestError, FromHttpResponseError, IntoHttpError, MatrixError, ServerError,
},
AuthScheme, EndpointError, IncomingResponse, OutgoingRequest, SendAccessToken,
},
assign, directory, encryption, events, identifiers, int, presence, push, receipt,
serde::{CanonicalJsonValue, Raw},
thirdparty, uint, Int, MilliSecondsSinceUnixEpoch, Outgoing, SecondsSinceUnixEpoch, UInt,
};
mod client; mod client;
mod error; mod error;
@ -115,9 +100,7 @@ mod room_member;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
mod device; mod device;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
mod sas; pub mod verification;
#[cfg(feature = "encryption")]
mod verification_request;
pub use client::{Client, ClientConfig, LoopCtrl, RequestConfig, SyncSettings}; pub use client::{Client, ClientConfig, LoopCtrl, RequestConfig, SyncSettings};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
@ -127,12 +110,5 @@ pub use error::{Error, HttpError, Result};
pub use event_handler::{CustomEvent, EventHandler}; pub use event_handler::{CustomEvent, EventHandler};
pub use http_client::HttpSend; pub use http_client::HttpSend;
pub use room_member::RoomMember; pub use room_member::RoomMember;
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use sas::Sas;
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use verification_request::VerificationRequest;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
pub(crate) const VERSION: &str = env!("CARGO_PKG_VERSION"); pub(crate) const VERSION: &str = env!("CARGO_PKG_VERSION");

View File

@ -4,16 +4,19 @@ use matrix_sdk_base::deserialized_responses::MembersResponse;
use matrix_sdk_common::locks::Mutex; use matrix_sdk_common::locks::Mutex;
use ruma::{ use ruma::{
api::client::r0::{ api::client::r0::{
media::{get_content, get_content_thumbnail},
membership::{get_member_events, join_room_by_id, leave_room}, membership::{get_member_events, join_room_by_id, leave_room},
message::get_message_events, message::get_message_events,
}, },
events::{AnySyncStateEvent, EventType}, events::{room::history_visibility::HistoryVisibility, AnySyncStateEvent, EventType},
serde::Raw, serde::Raw,
UserId, UserId,
}; };
use crate::{BaseRoom, Client, Result, RoomMember}; use crate::{
media::{MediaFormat, MediaRequest, MediaType},
room::RoomType,
BaseRoom, Client, Result, RoomMember,
};
/// A struct containing methods that are common for Joined, Invited and Left /// A struct containing methods that are common for Joined, Invited and Left
/// Rooms /// Rooms
@ -65,20 +68,20 @@ impl Common {
/// Gets the avatar of this room, if set. /// Gets the avatar of this room, if set.
/// ///
/// Returns the avatar. No guarantee on the size of the image is given. /// Returns the avatar.
/// If no size is given the full-sized avatar will be returned. /// If a thumbnail is requested no guarantee on the size of the image is
/// given.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `width` - The desired width of the avatar. /// * `format` - The desired format of the avatar.
///
/// * `height` - The desired height of the avatar.
/// ///
/// # Example /// # Example
/// ```no_run /// ```no_run
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use matrix_sdk::Client; /// # use matrix_sdk::Client;
/// # use matrix_sdk::identifiers::room_id; /// # use matrix_sdk::ruma::room_id;
/// # use matrix_sdk::media::MediaFormat;
/// # use url::Url; /// # use url::Url;
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
/// # block_on(async { /// # block_on(async {
@ -89,24 +92,15 @@ impl Common {
/// let room = client /// let room = client
/// .get_joined_room(&room_id) /// .get_joined_room(&room_id)
/// .unwrap(); /// .unwrap();
/// if let Some(avatar) = room.avatar(Some(96), Some(96)).await.unwrap() { /// if let Some(avatar) = room.avatar(MediaFormat::File).await.unwrap() {
/// std::fs::write("avatar.png", avatar); /// std::fs::write("avatar.png", avatar);
/// } /// }
/// # }) /// # })
/// ``` /// ```
pub async fn avatar(&self, width: Option<u32>, height: Option<u32>) -> Result<Option<Vec<u8>>> { pub async fn avatar(&self, format: MediaFormat) -> Result<Option<Vec<u8>>> {
// TODO: try to offer the avatar from cache, requires avatar cache
if let Some(url) = self.avatar_url() { if let Some(url) = self.avatar_url() {
if let (Some(width), Some(height)) = (width, height) { let request = MediaRequest { media_type: MediaType::Uri(url.clone()), format };
let request = Ok(Some(self.client.get_media_content(&request, true).await?))
get_content_thumbnail::Request::from_url(&url, width.into(), height.into())?;
let response = self.client.send(request, None).await?;
Ok(Some(response.file))
} else {
let request = get_content::Request::from_url(&url)?;
let response = self.client.send(request, None).await?;
Ok(Some(response.file))
}
} else { } else {
Ok(None) Ok(None)
} }
@ -125,9 +119,11 @@ impl Common {
/// ```no_run /// ```no_run
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// use matrix_sdk::Client; /// use matrix_sdk::Client;
/// # use matrix_sdk::identifiers::room_id; /// # use matrix_sdk::ruma::room_id;
/// # use matrix_sdk::api::r0::filter::RoomEventFilter; /// # use matrix_sdk::ruma::api::client::r0::{
/// # use matrix_sdk::api::r0::message::get_message_events::Request as MessagesRequest; /// # filter::RoomEventFilter,
/// # message::get_message_events::Request as MessagesRequest,
/// # };
/// # use url::Url; /// # use url::Url;
/// ///
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
@ -178,6 +174,10 @@ impl Common {
} }
async fn ensure_members(&self) -> Result<()> { async fn ensure_members(&self) -> Result<()> {
if !self.are_events_visible() {
return Ok(());
}
if !self.are_members_synced() { if !self.are_members_synced() {
self.request_members().await?; self.request_members().await?;
} }
@ -185,6 +185,17 @@ impl Common {
Ok(()) Ok(())
} }
fn are_events_visible(&self) -> bool {
if let RoomType::Invited = self.inner.room_type() {
return matches!(
self.inner.history_visibility(),
HistoryVisibility::WorldReadable | HistoryVisibility::Invited
);
}
true
}
/// Sync the member list with the server. /// Sync the member list with the server.
/// ///
/// This method will de-duplicate requests if it is called multiple times in /// This method will de-duplicate requests if it is called multiple times in

View File

@ -38,8 +38,8 @@ use ruma::{
}, },
AnyMessageEventContent, AnyStateEventContent, AnyMessageEventContent, AnyStateEventContent,
}, },
identifiers::{EventId, UserId},
receipt::ReceiptType, receipt::ReceiptType,
EventId, UserId,
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use tracing::instrument; use tracing::instrument;
@ -159,10 +159,10 @@ impl Joined {
/// ///
/// ```no_run /// ```no_run
/// use std::time::Duration; /// use std::time::Duration;
/// use matrix_sdk::api::r0::typing::create_typing_event::Typing; /// use matrix_sdk::ruma::api::client::r0::typing::create_typing_event::Typing;
/// # use matrix_sdk::{ /// # use matrix_sdk::{
/// # Client, SyncSettings, /// # Client, SyncSettings,
/// # identifiers::room_id, /// # ruma::room_id,
/// # }; /// # };
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use url::Url; /// # use url::Url;
@ -349,9 +349,9 @@ impl Joined {
/// # use matrix_sdk::{Client, SyncSettings}; /// # use matrix_sdk::{Client, SyncSettings};
/// # use url::Url; /// # use url::Url;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use matrix_sdk::identifiers::room_id; /// # use matrix_sdk::ruma::room_id;
/// # use std::convert::TryFrom; /// # use std::convert::TryFrom;
/// use matrix_sdk::events::{ /// use matrix_sdk::ruma::events::{
/// AnyMessageEventContent, /// AnyMessageEventContent,
/// room::message::{MessageEventContent, TextMessageEventContent}, /// room::message::{MessageEventContent, TextMessageEventContent},
/// }; /// };
@ -431,7 +431,7 @@ impl Joined {
/// ///
/// ```no_run /// ```no_run
/// # use std::{path::PathBuf, fs::File, io::Read}; /// # use std::{path::PathBuf, fs::File, io::Read};
/// # use matrix_sdk::{Client, identifiers::room_id}; /// # use matrix_sdk::{Client, ruma::room_id};
/// # use url::Url; /// # use url::Url;
/// # use mime; /// # use mime;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
@ -532,18 +532,17 @@ impl Joined {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use matrix_sdk::{ /// use matrix_sdk::ruma::{
/// events::{ /// events::{
/// AnyStateEventContent, /// AnyStateEventContent,
/// room::member::{MemberEventContent, MembershipState}, /// room::member::{MemberEventContent, MembershipState},
/// }, /// },
/// identifiers::mxc_uri, /// assign, mxc_uri,
/// assign,
/// }; /// };
/// # futures::executor::block_on(async { /// # futures::executor::block_on(async {
/// # let homeserver = url::Url::parse("http://localhost:8080").unwrap(); /// # let homeserver = url::Url::parse("http://localhost:8080").unwrap();
/// # let mut client = matrix_sdk::Client::new(homeserver).unwrap(); /// # let mut client = matrix_sdk::Client::new(homeserver).unwrap();
/// # let room_id = matrix_sdk::identifiers::room_id!("!test:localhost"); /// # let room_id = matrix_sdk::ruma::room_id!("!test:localhost");
/// ///
/// let avatar_url = mxc_uri!("mxc://example.org/avatar"); /// let avatar_url = mxc_uri!("mxc://example.org/avatar");
/// let member_event = assign!(MemberEventContent::new(MembershipState::Join), { /// let member_event = assign!(MemberEventContent::new(MembershipState::Join), {
@ -591,11 +590,11 @@ impl Joined {
/// # futures::executor::block_on(async { /// # futures::executor::block_on(async {
/// # let homeserver = url::Url::parse("http://localhost:8080").unwrap(); /// # let homeserver = url::Url::parse("http://localhost:8080").unwrap();
/// # let mut client = matrix_sdk::Client::new(homeserver).unwrap(); /// # let mut client = matrix_sdk::Client::new(homeserver).unwrap();
/// # let room_id = matrix_sdk::identifiers::room_id!("!test:localhost"); /// # let room_id = matrix_sdk::ruma::room_id!("!test:localhost");
/// # let room = client /// # let room = client
/// # .get_joined_room(&room_id) /// # .get_joined_room(&room_id)
/// # .unwrap(); /// # .unwrap();
/// let event_id = matrix_sdk::identifiers::event_id!("$xxxxxx:example.org"); /// let event_id = matrix_sdk::ruma::event_id!("$xxxxxx:example.org");
/// let reason = Some("Indecent material"); /// let reason = Some("Indecent material");
/// room.redact(&event_id, reason, None).await.unwrap(); /// room.redact(&event_id, reason, None).await.unwrap();
/// # }) /// # })

View File

@ -1,8 +1,9 @@
use std::ops::Deref; use std::ops::Deref;
use ruma::api::client::r0::media::{get_content, get_content_thumbnail}; use crate::{
media::{MediaFormat, MediaRequest, MediaType},
use crate::{BaseRoomMember, Client, Result}; BaseRoomMember, Client, Result,
};
/// The high-level `RoomMember` representation /// The high-level `RoomMember` representation
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -26,21 +27,21 @@ impl RoomMember {
/// Gets the avatar of this member, if set. /// Gets the avatar of this member, if set.
/// ///
/// Returns the avatar. No guarantee on the size of the image is given. /// Returns the avatar.
/// If no size is given the full-sized avatar will be returned. /// If a thumbnail is requested no guarantee on the size of the image is
/// given.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `width` - The desired width of the avatar. /// * `format` - The desired format of the avatar.
///
/// * `height` - The desired height of the avatar.
/// ///
/// # Example /// # Example
/// ```no_run /// ```no_run
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use matrix_sdk::Client; /// # use matrix_sdk::Client;
/// # use matrix_sdk::identifiers::room_id; /// # use matrix_sdk::ruma::room_id;
/// # use matrix_sdk::RoomMember; /// # use matrix_sdk::RoomMember;
/// # use matrix_sdk::media::MediaFormat;
/// # use url::Url; /// # use url::Url;
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
/// # block_on(async { /// # block_on(async {
@ -53,24 +54,15 @@ impl RoomMember {
/// .unwrap(); /// .unwrap();
/// let members = room.members().await.unwrap(); /// let members = room.members().await.unwrap();
/// let member = members.first().unwrap(); /// let member = members.first().unwrap();
/// if let Some(avatar) = member.avatar(Some(96), Some(96)).await.unwrap() { /// if let Some(avatar) = member.avatar(MediaFormat::File).await.unwrap() {
/// std::fs::write("avatar.png", avatar); /// std::fs::write("avatar.png", avatar);
/// } /// }
/// # }) /// # })
/// ``` /// ```
pub async fn avatar(&self, width: Option<u32>, height: Option<u32>) -> Result<Option<Vec<u8>>> { pub async fn avatar(&self, format: MediaFormat) -> Result<Option<Vec<u8>>> {
// TODO: try to offer the avatar from cache, requires avatar cache
if let Some(url) = self.avatar_url() { if let Some(url) = self.avatar_url() {
if let (Some(width), Some(height)) = (width, height) { let request = MediaRequest { media_type: MediaType::Uri(url.clone()), format };
let request = Ok(Some(self.client.get_media_content(&request, true).await?))
get_content_thumbnail::Request::from_url(url, width.into(), height.into())?;
let response = self.client.send(request, None).await?;
Ok(Some(response.file))
} else {
let request = get_content::Request::from_url(url)?;
let response = self.client.send(request, None).await?;
Ok(Some(response.file))
}
} else { } else {
Ok(None) Ok(None)
} }

View File

@ -0,0 +1,138 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Interactive verification for E2EE capable users and devices in Matrix.
//!
//! The SDK supports interactive verification of devices and users, this module
//! contains types that model and support different verification flows.
//!
//! A verification flow usually starts its life as a [VerificationRequest], the
//! request can then be accepted, or it needs to be accepted by the other side
//! of the verification flow.
//!
//! Once both sides have agreed to pereform the verification, and the
//! [VerificationRequest::is_ready()] method returns true, the verification can
//! transition into one of the supported verification flows:
//!
//! * [SasVerification] - Interactive verification using a short authentication
//! string.
//! * [QrVerification] - Interactive verification using QR codes.
mod qrcode;
mod requests;
mod sas;
pub use matrix_sdk_base::crypto::{AcceptSettings, CancelInfo};
pub use qrcode::QrVerification;
pub use requests::VerificationRequest;
pub use sas::SasVerification;
/// An enum over the different verification types the SDK supports.
#[derive(Debug, Clone)]
pub enum Verification {
/// The `m.sas.v1` verification variant.
SasV1(SasVerification),
/// The `m.qr_code.*.v1` verification variant.
QrV1(QrVerification),
}
impl Verification {
/// Try to deconstruct this verification enum into a SAS verification.
pub fn sas(self) -> Option<SasVerification> {
if let Verification::SasV1(sas) = self {
Some(sas)
} else {
None
}
}
/// Try to deconstruct this verification enum into a QR code verification.
pub fn qr(self) -> Option<QrVerification> {
if let Verification::QrV1(qr) = self {
Some(qr)
} else {
None
}
}
/// Has this verification finished.
pub fn is_done(&self) -> bool {
match self {
Verification::SasV1(s) => s.is_done(),
Verification::QrV1(qr) => qr.is_done(),
}
}
/// Has the verification been cancelled.
pub fn is_cancelled(&self) -> bool {
match self {
Verification::SasV1(s) => s.is_cancelled(),
Verification::QrV1(qr) => qr.is_cancelled(),
}
}
/// Get info about the cancellation if the verification flow has been
/// cancelled.
pub fn cancel_info(&self) -> Option<CancelInfo> {
match self {
Verification::SasV1(s) => s.cancel_info(),
Verification::QrV1(q) => q.cancel_info(),
}
}
/// Get our own user id.
pub fn own_user_id(&self) -> &ruma::UserId {
match self {
Verification::SasV1(v) => v.own_user_id(),
Verification::QrV1(v) => v.own_user_id(),
}
}
/// Get the user id of the other user participating in this verification
/// flow.
pub fn other_user_id(&self) -> &ruma::UserId {
match self {
Verification::SasV1(v) => v.inner.other_user_id(),
Verification::QrV1(v) => v.inner.other_user_id(),
}
}
/// Is this a verification that is veryfying one of our own devices.
pub fn is_self_verification(&self) -> bool {
match self {
Verification::SasV1(v) => v.is_self_verification(),
Verification::QrV1(v) => v.is_self_verification(),
}
}
/// Did we initiate the verification flow.
pub fn we_started(&self) -> bool {
match self {
Verification::SasV1(s) => s.we_started(),
Verification::QrV1(q) => q.we_started(),
}
}
}
impl From<SasVerification> for Verification {
fn from(sas: SasVerification) -> Self {
Self::SasV1(sas)
}
}
impl From<QrVerification> for Verification {
fn from(qr: QrVerification) -> Self {
Self::QrV1(qr)
}
}

View File

@ -0,0 +1,104 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use matrix_sdk_base::crypto::{
matrix_qrcode::{qrcode::QrCode, EncodingError},
CancelInfo, QrVerification as BaseQrVerification,
};
use ruma::UserId;
use crate::{Client, Result};
/// An object controlling QR code style key verification flows.
#[derive(Debug, Clone)]
pub struct QrVerification {
pub(crate) inner: BaseQrVerification,
pub(crate) client: Client,
}
impl QrVerification {
/// Get our own user id.
pub fn own_user_id(&self) -> &UserId {
self.inner.user_id()
}
/// Is this a verification that is veryfying one of our own devices.
pub fn is_self_verification(&self) -> bool {
self.inner.is_self_verification()
}
/// Has this verification finished.
pub fn is_done(&self) -> bool {
self.inner.is_done()
}
/// Did we initiate the verification flow.
pub fn we_started(&self) -> bool {
self.inner.we_started()
}
/// Get info about the cancellation if the verification flow has been
/// cancelled.
pub fn cancel_info(&self) -> Option<CancelInfo> {
self.inner.cancel_info()
}
/// Get the user id of the other user participating in this verification
/// flow.
pub fn other_user_id(&self) -> &UserId {
self.inner.other_user_id()
}
/// Has the verification been cancelled.
pub fn is_cancelled(&self) -> bool {
self.inner.is_cancelled()
}
/// Generate a QR code object that is representing this verification flow.
///
/// The `QrCode` can then be rendered as an image or as an unicode string.
///
/// The [`to_bytes()`](#method.to_bytes) method can be used to instead
/// output the raw bytes that should be encoded as a QR code.
pub fn to_qr_code(&self) -> std::result::Result<QrCode, EncodingError> {
self.inner.to_qr_code()
}
/// Generate a the raw bytes that should be encoded as a QR code is
/// representing this verification flow.
///
/// The [`to_qr_code()`](#method.to_qr_code) method can be used to instead
/// output a `QrCode` object that can be rendered.
pub fn to_bytes(&self) -> std::result::Result<Vec<u8>, EncodingError> {
self.inner.to_bytes()
}
/// Confirm that the other side has scanned our QR code.
pub async fn confirm(&self) -> Result<()> {
if let Some(request) = self.inner.confirm_scanning() {
self.client.send_verification_request(request).await?;
}
Ok(())
}
/// Abort the verification flow and notify the other side that we did so.
pub async fn cancel(&self) -> Result<()> {
if let Some(request) = self.inner.cancel() {
self.client.send_verification_request(request).await?;
}
Ok(())
}
}

View File

@ -0,0 +1,143 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use matrix_sdk_base::crypto::{CancelInfo, VerificationRequest as BaseVerificationRequest};
use ruma::events::key::verification::VerificationMethod;
use super::{QrVerification, SasVerification};
use crate::{Client, Result};
/// An object controlling the interactive verification flow.
#[derive(Debug, Clone)]
pub struct VerificationRequest {
pub(crate) inner: BaseVerificationRequest,
pub(crate) client: Client,
}
impl VerificationRequest {
/// Has this verification finished.
pub fn is_done(&self) -> bool {
self.inner.is_done()
}
/// Has the verification been cancelled.
pub fn is_cancelled(&self) -> bool {
self.inner.is_cancelled()
}
/// Get info about the cancellation if the verification request has been
/// cancelled.
pub fn cancel_info(&self) -> Option<CancelInfo> {
self.inner.cancel_info()
}
/// Get our own user id.
pub fn own_user_id(&self) -> &ruma::UserId {
self.inner.own_user_id()
}
/// Has the verification request been answered by another device.
pub fn is_passive(&self) -> bool {
self.inner.is_passive()
}
/// Is the verification request ready to start a verification flow.
pub fn is_ready(&self) -> bool {
self.inner.is_ready()
}
/// Did we initiate the verification flow.
pub fn we_started(&self) -> bool {
self.inner.we_started()
}
/// Get the user id of the other user participating in this verification
/// flow.
pub fn other_user_id(&self) -> &ruma::UserId {
self.inner.other_user()
}
/// Is this a verification that is veryfying one of our own devices.
pub fn is_self_verification(&self) -> bool {
self.inner.is_self_verification()
}
/// Get the supported verification methods of the other side.
///
/// Will be present only if the other side requested the verification or if
/// we're in the ready state.
pub fn their_supported_methods(&self) -> Option<Vec<VerificationMethod>> {
self.inner.their_supported_methods()
}
/// Accept the verification request.
///
/// This method will accept the request and signal that it supports the
/// `m.sas.v1`, the `m.qr_code.show.v1`, and `m.reciprocate.v1` method.
///
/// If QR code scanning should be supported or QR code showing shouldn't be
/// supported the [`accept_with_methods()`] method should be used instead.
///
/// [`accept_with_methods()`]: #method.accept_with_methods
pub async fn accept(&self) -> Result<()> {
if let Some(request) = self.inner.accept() {
self.client.send_verification_request(request).await?;
}
Ok(())
}
/// Accept the verification request signaling that our client supports the
/// given verification methods.
///
/// # Arguments
///
/// * `methods` - The methods that we should advertise as supported by us.
pub async fn accept_with_methods(&self, methods: Vec<VerificationMethod>) -> Result<()> {
if let Some(request) = self.inner.accept_with_methods(methods) {
self.client.send_verification_request(request).await?;
}
Ok(())
}
/// Generate a QR code
pub async fn generate_qr_code(&self) -> Result<Option<QrVerification>> {
Ok(self
.inner
.generate_qr_code()
.await?
.map(|qr| QrVerification { inner: qr, client: self.client.clone() }))
}
/// Transition from this verification request into a SAS verification flow.
pub async fn start_sas(&self) -> Result<Option<SasVerification>> {
if let Some((sas, request)) = self.inner.start_sas().await? {
self.client.send_verification_request(request).await?;
Ok(Some(SasVerification { inner: sas, client: self.client.clone() }))
} else {
Ok(None)
}
}
/// Cancel the verification request
pub async fn cancel(&self) -> Result<()> {
if let Some(request) = self.inner.cancel() {
self.client.send_verification_request(request).await?;
}
Ok(())
}
}

View File

@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use matrix_sdk_base::crypto::{ use matrix_sdk_base::crypto::{AcceptSettings, CancelInfo, ReadOnlyDevice, Sas as BaseSas};
AcceptSettings, OutgoingVerificationRequest, ReadOnlyDevice, Sas as BaseSas, use ruma::UserId;
};
use crate::{error::Result, Client}; use crate::{error::Result, Client};
/// An object controlling the interactive verification flow. /// An object controlling the interactive verification flow.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Sas { pub struct SasVerification {
pub(crate) inner: BaseSas, pub(crate) inner: BaseSas,
pub(crate) client: Client, pub(crate) client: Client,
} }
impl Sas { impl SasVerification {
/// Accept the interactive verification flow. /// Accept the interactive verification flow.
pub async fn accept(&self) -> Result<()> { pub async fn accept(&self) -> Result<()> {
self.accept_with_settings(Default::default()).await self.accept_with_settings(Default::default()).await
@ -43,14 +42,21 @@ impl Sas {
/// # use matrix_sdk::Client; /// # use matrix_sdk::Client;
/// # use futures::executor::block_on; /// # use futures::executor::block_on;
/// # use url::Url; /// # use url::Url;
/// use matrix_sdk::Sas; /// # use ruma::user_id;
/// use matrix_sdk::verification::SasVerification;
/// use matrix_sdk_base::crypto::AcceptSettings; /// use matrix_sdk_base::crypto::AcceptSettings;
/// use matrix_sdk::events::key::verification::ShortAuthenticationString; /// use matrix_sdk::ruma::events::key::verification::ShortAuthenticationString;
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap(); /// # let client = Client::new(homeserver).unwrap();
/// # let flow_id = "someID"; /// # let flow_id = "someID";
/// # let user_id = user_id!("@alice:example");
/// # block_on(async { /// # block_on(async {
/// let sas = client.get_verification(flow_id).await.unwrap(); /// let sas = client
/// .get_verification(&user_id, flow_id)
/// .await
/// .unwrap()
/// .sas()
/// .unwrap();
/// ///
/// let only_decimal = AcceptSettings::with_allowed_methods( /// let only_decimal = AcceptSettings::with_allowed_methods(
/// vec![ShortAuthenticationString::Decimal] /// vec![ShortAuthenticationString::Decimal]
@ -59,15 +65,8 @@ impl Sas {
/// # }); /// # });
/// ``` /// ```
pub async fn accept_with_settings(&self, settings: AcceptSettings) -> Result<()> { pub async fn accept_with_settings(&self, settings: AcceptSettings) -> Result<()> {
if let Some(req) = self.inner.accept_with_settings(settings) { if let Some(request) = self.inner.accept_with_settings(settings) {
match req { self.client.send_verification_request(request).await?;
OutgoingVerificationRequest::ToDevice(r) => {
self.client.send_to_device(&r).await?;
}
OutgoingVerificationRequest::InRoom(r) => {
self.client.room_send_helper(&r).await?;
}
}
} }
Ok(()) Ok(())
} }
@ -76,16 +75,8 @@ impl Sas {
pub async fn confirm(&self) -> Result<()> { pub async fn confirm(&self) -> Result<()> {
let (request, signature) = self.inner.confirm().await?; let (request, signature) = self.inner.confirm().await?;
match request { if let Some(request) = request {
Some(OutgoingVerificationRequest::InRoom(r)) => { self.client.send_verification_request(request).await?;
self.client.room_send_helper(&r).await?;
}
Some(OutgoingVerificationRequest::ToDevice(r)) => {
self.client.send_to_device(&r).await?;
}
None => (),
} }
if let Some(s) = signature { if let Some(s) = signature {
@ -98,14 +89,7 @@ impl Sas {
/// Cancel the interactive verification flow. /// Cancel the interactive verification flow.
pub async fn cancel(&self) -> Result<()> { pub async fn cancel(&self) -> Result<()> {
if let Some(request) = self.inner.cancel() { if let Some(request) = self.inner.cancel() {
match request { self.client.send_verification_request(request).await?;
OutgoingVerificationRequest::ToDevice(r) => {
self.client.send_to_device(&r).await?;
}
OutgoingVerificationRequest::InRoom(r) => {
self.client.room_send_helper(&r).await?;
}
}
} }
Ok(()) Ok(())
@ -132,6 +116,22 @@ impl Sas {
self.inner.is_done() self.inner.is_done()
} }
/// Are we in a state where we can show the short auth string.
pub fn can_be_presented(&self) -> bool {
self.inner.can_be_presented()
}
/// Did we initiate the verification flow.
pub fn we_started(&self) -> bool {
self.inner.we_started()
}
/// Get info about the cancellation if the verification flow has been
/// cancelled.
pub fn cancel_info(&self) -> Option<CancelInfo> {
self.inner.cancel_info()
}
/// Is the verification process canceled. /// Is the verification process canceled.
pub fn is_cancelled(&self) -> bool { pub fn is_cancelled(&self) -> bool {
self.inner.is_cancelled() self.inner.is_cancelled()
@ -141,4 +141,25 @@ impl Sas {
pub fn other_device(&self) -> &ReadOnlyDevice { pub fn other_device(&self) -> &ReadOnlyDevice {
self.inner.other_device() self.inner.other_device()
} }
/// Did this verification flow start from a verification request.
pub fn started_from_request(&self) -> bool {
self.inner.started_from_request()
}
/// Is this a verification that is veryfying one of our own devices.
pub fn is_self_verification(&self) -> bool {
self.inner.is_self_verification()
}
/// Get our own user id.
pub fn own_user_id(&self) -> &UserId {
self.inner.user_id()
}
/// Get the user id of the other user participating in this verification
/// flow.
pub fn other_user_id(&self) -> &UserId {
self.inner.other_user_id()
}
} }

View File

@ -1,49 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use matrix_sdk_base::crypto::{
OutgoingVerificationRequest, VerificationRequest as BaseVerificationRequest,
};
use crate::{Client, Result};
/// An object controlling the interactive verification flow.
#[derive(Debug, Clone)]
pub struct VerificationRequest {
pub(crate) inner: BaseVerificationRequest,
pub(crate) client: Client,
}
impl VerificationRequest {
/// Accept the verification request
pub async fn accept(&self) -> Result<()> {
if let Some(request) = self.inner.accept() {
match request {
OutgoingVerificationRequest::ToDevice(r) => {
self.client.send_to_device(&r).await?;
}
OutgoingVerificationRequest::InRoom(r) => {
self.client.room_send_helper(&r).await?;
}
};
}
Ok(())
}
/// Cancel the verification request
pub async fn cancel(&self) -> Result<()> {
todo!()
}
}

View File

@ -9,13 +9,10 @@ version = "0.1.0"
[features] [features]
default = ["warp"] default = ["warp"]
actix = ["actix-rt", "actix-web"]
docs = ["actix", "warp"] docs = ["warp"]
[dependencies] [dependencies]
actix-rt = { version = "2", optional = true }
actix-web = { version = "4.0.0-beta.6", optional = true }
dashmap = "4" dashmap = "4"
futures = "0.3" futures = "0.3"
futures-util = "0.3" futures-util = "0.3"
@ -29,10 +26,10 @@ tracing = "0.1"
url = "2" url = "2"
warp = { git = "https://github.com/seanmonstar/warp.git", rev = "629405", optional = true, default-features = false } warp = { git = "https://github.com/seanmonstar/warp.git", rev = "629405", optional = true, default-features = false }
matrix-sdk = { version = "0.2", path = "../matrix_sdk", default-features = false, features = ["appservice", "native-tls"] } matrix-sdk = { version = "0.3", path = "../matrix_sdk", default-features = false, features = ["appservice", "native-tls"] }
[dependencies.ruma] [dependencies.ruma]
version = "0.1.2" version = "0.2.0"
features = ["client-api-c", "appservice-api-s", "unstable-pre-spec"] features = ["client-api-c", "appservice-api-s", "unstable-pre-spec"]
[dev-dependencies] [dev-dependencies]
@ -41,7 +38,7 @@ mockito = "0.30"
tokio = { version = "1", default-features = false, features = ["rt-multi-thread", "macros"] } tokio = { version = "1", default-features = false, features = ["rt-multi-thread", "macros"] }
tracing-subscriber = "0.2" tracing-subscriber = "0.2"
matrix-sdk-test = { version = "0.2", path = "../matrix_sdk_test", features = ["appservice"] } matrix-sdk-test = { version = "0.3", path = "../matrix_sdk_test", features = ["appservice"] }
[[example]] [[example]]
name = "appservice_autojoin" name = "appservice_autojoin"

View File

@ -3,24 +3,26 @@ use std::{convert::TryFrom, env};
use matrix_sdk_appservice::{ use matrix_sdk_appservice::{
matrix_sdk::{ matrix_sdk::{
async_trait, async_trait,
events::{
room::member::{MemberEventContent, MembershipState},
SyncStateEvent,
},
identifiers::UserId,
room::Room, room::Room,
ruma::{
events::{
room::member::{MemberEventContent, MembershipState},
SyncStateEvent,
},
UserId,
},
EventHandler, EventHandler,
}, },
Appservice, AppserviceRegistration, AppService, AppServiceRegistration,
}; };
use tracing::{error, trace}; use tracing::{error, trace};
struct AppserviceEventHandler { struct AppServiceEventHandler {
appservice: Appservice, appservice: AppService,
} }
impl AppserviceEventHandler { impl AppServiceEventHandler {
pub fn new(appservice: Appservice) -> Self { pub fn new(appservice: AppService) -> Self {
Self { appservice } Self { appservice }
} }
@ -47,7 +49,7 @@ impl AppserviceEventHandler {
} }
#[async_trait] #[async_trait]
impl EventHandler for AppserviceEventHandler { impl EventHandler for AppServiceEventHandler {
async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) { async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) {
match self.handle_room_member(room, event).await { match self.handle_room_member(room, event).await {
Ok(_) => (), Ok(_) => (),
@ -63,10 +65,10 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
let homeserver_url = "http://localhost:8008"; let homeserver_url = "http://localhost:8008";
let server_name = "localhost"; let server_name = "localhost";
let registration = AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml")?; let registration = AppServiceRegistration::try_from_yaml_file("./tests/registration.yaml")?;
let mut appservice = Appservice::new(homeserver_url, server_name, registration).await?; let mut appservice = AppService::new(homeserver_url, server_name, registration).await?;
appservice.set_event_handler(Box::new(AppserviceEventHandler::new(appservice.clone()))).await?; appservice.set_event_handler(Box::new(AppServiceEventHandler::new(appservice.clone()))).await?;
let (host, port) = appservice.registration().get_host_and_port()?; let (host, port) = appservice.registration().get_host_and_port()?;
appservice.run(host, port).await?; appservice.run(host, port).await?;

View File

@ -70,19 +70,8 @@ pub enum Error {
#[cfg(feature = "warp")] #[cfg(feature = "warp")]
#[error("warp rejection: {0}")] #[error("warp rejection: {0}")]
WarpRejection(String), WarpRejection(String),
#[cfg(feature = "actix")]
#[error(transparent)]
Actix(#[from] actix_web::Error),
#[cfg(feature = "actix")]
#[error(transparent)]
ActixPayload(#[from] actix_web::error::PayloadError),
} }
#[cfg(feature = "actix")]
impl actix_web::error::ResponseError for Error {}
#[cfg(feature = "warp")] #[cfg(feature = "warp")]
impl warp::reject::Reject for Error {} impl warp::reject::Reject for Error {}

View File

@ -41,11 +41,11 @@
//! # #[async_trait] //! # #[async_trait]
//! # impl EventHandler for MyEventHandler {} //! # impl EventHandler for MyEventHandler {}
//! # //! #
//! use matrix_sdk_appservice::{Appservice, AppserviceRegistration}; //! use matrix_sdk_appservice::{AppService, AppServiceRegistration};
//! //!
//! let homeserver_url = "http://127.0.0.1:8008"; //! let homeserver_url = "http://127.0.0.1:8008";
//! let server_name = "localhost"; //! let server_name = "localhost";
//! let registration = AppserviceRegistration::try_from_yaml_str( //! let registration = AppServiceRegistration::try_from_yaml_str(
//! r" //! r"
//! id: appservice //! id: appservice
//! url: http://127.0.0.1:9009 //! url: http://127.0.0.1:9009
@ -58,7 +58,7 @@
//! regex: '@_appservice_.*' //! regex: '@_appservice_.*'
//! ")?; //! ")?;
//! //!
//! let mut appservice = Appservice::new(homeserver_url, server_name, registration).await?; //! let mut appservice = AppService::new(homeserver_url, server_name, registration).await?;
//! appservice.set_event_handler(Box::new(MyEventHandler)).await?; //! appservice.set_event_handler(Box::new(MyEventHandler)).await?;
//! //!
//! let (host, port) = appservice.registration().get_host_and_port()?; //! let (host, port) = appservice.registration().get_host_and_port()?;
@ -74,8 +74,8 @@
//! [matrix-org/matrix-rust-sdk#228]: https://github.com/matrix-org/matrix-rust-sdk/issues/228 //! [matrix-org/matrix-rust-sdk#228]: https://github.com/matrix-org/matrix-rust-sdk/issues/228
//! [examples directory]: https://github.com/matrix-org/matrix-rust-sdk/tree/master/matrix_sdk_appservice/examples //! [examples directory]: https://github.com/matrix-org/matrix-rust-sdk/tree/master/matrix_sdk_appservice/examples
#[cfg(not(any(feature = "actix", feature = "warp")))] #[cfg(not(any(feature = "warp")))]
compile_error!("one webserver feature must be enabled. available ones: `actix`, `warp`"); compile_error!("one webserver feature must be enabled. available ones: `warp`");
use std::{ use std::{
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
@ -89,12 +89,15 @@ use dashmap::DashMap;
pub use error::Error; pub use error::Error;
use http::{uri::PathAndQuery, Uri}; use http::{uri::PathAndQuery, Uri};
pub use matrix_sdk; pub use matrix_sdk;
use matrix_sdk::{reqwest::Url, Bytes, Client, ClientConfig, EventHandler, HttpError, Session}; #[doc(no_inline)]
pub use matrix_sdk::ruma;
use matrix_sdk::{
bytes::Bytes, reqwest::Url, Client, ClientConfig, EventHandler, HttpError, Session,
};
use regex::Regex; use regex::Regex;
#[doc(inline)]
pub use ruma::api::{appservice as api, appservice::Registration};
use ruma::{ use ruma::{
api::{ api::{
appservice::Registration,
client::{ client::{
error::ErrorKind, error::ErrorKind,
r0::{account::register, uiaa::UiaaResponse}, r0::{account::register, uiaa::UiaaResponse},
@ -112,15 +115,15 @@ pub type Result<T> = std::result::Result<T, Error>;
pub type Host = String; pub type Host = String;
pub type Port = u16; pub type Port = u16;
/// Appservice Registration /// AppService Registration
/// ///
/// Wrapper around [`Registration`] /// Wrapper around [`Registration`]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct AppserviceRegistration { pub struct AppServiceRegistration {
inner: Registration, inner: Registration,
} }
impl AppserviceRegistration { impl AppServiceRegistration {
/// Try to load registration from yaml string /// Try to load registration from yaml string
/// ///
/// See the fields of [`Registration`] for the required format /// See the fields of [`Registration`] for the required format
@ -158,13 +161,13 @@ impl AppserviceRegistration {
} }
} }
impl From<Registration> for AppserviceRegistration { impl From<Registration> for AppServiceRegistration {
fn from(value: Registration) -> Self { fn from(value: Registration) -> Self {
Self { inner: value } Self { inner: value }
} }
} }
impl Deref for AppserviceRegistration { impl Deref for AppServiceRegistration {
type Target = Registration; type Target = Registration;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@ -175,7 +178,7 @@ impl Deref for AppserviceRegistration {
type Localpart = String; type Localpart = String;
/// The `localpart` of the user associated with the application service via /// The `localpart` of the user associated with the application service via
/// `sender_localpart` in [`AppserviceRegistration`]. /// `sender_localpart` in [`AppServiceRegistration`].
/// ///
/// Dummy type for shared documentation /// Dummy type for shared documentation
#[allow(dead_code)] #[allow(dead_code)]
@ -183,23 +186,23 @@ pub type MainUser = ();
/// The application service may specify the virtual user to act as through use /// The application service may specify the virtual user to act as through use
/// of a user_id query string parameter on the request. The user specified in /// of a user_id query string parameter on the request. The user specified in
/// the query string must be covered by one of the [`AppserviceRegistration`]'s /// the query string must be covered by one of the [`AppServiceRegistration`]'s
/// `users` namespaces. /// `users` namespaces.
/// ///
/// Dummy type for shared documentation /// Dummy type for shared documentation
pub type VirtualUser = (); pub type VirtualUser = ();
/// Appservice /// AppService
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Appservice { pub struct AppService {
homeserver_url: Url, homeserver_url: Url,
server_name: ServerNameBox, server_name: ServerNameBox,
registration: Arc<AppserviceRegistration>, registration: Arc<AppServiceRegistration>,
clients: Arc<DashMap<Localpart, Client>>, clients: Arc<DashMap<Localpart, Client>>,
} }
impl Appservice { impl AppService {
/// Create new Appservice /// Create new AppService
/// ///
/// Also creates and caches a [`Client`] for the [`MainUser`]. /// Also creates and caches a [`Client`] for the [`MainUser`].
/// The default [`ClientConfig`] is used, if you want to customize it /// The default [`ClientConfig`] is used, if you want to customize it
@ -210,14 +213,14 @@ impl Appservice {
/// * `homeserver_url` - The homeserver that the client should connect to. /// * `homeserver_url` - The homeserver that the client should connect to.
/// * `server_name` - The server name to use when constructing user ids from /// * `server_name` - The server name to use when constructing user ids from
/// the localpart. /// the localpart.
/// * `registration` - The [Appservice Registration] to use when interacting /// * `registration` - The [AppService Registration] to use when interacting
/// with the homeserver. /// with the homeserver.
/// ///
/// [Appservice Registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration /// [AppService Registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
pub async fn new( pub async fn new(
homeserver_url: impl TryInto<Url, Error = url::ParseError>, homeserver_url: impl TryInto<Url, Error = url::ParseError>,
server_name: impl TryInto<ServerNameBox, Error = identifiers::Error>, server_name: impl TryInto<ServerNameBox, Error = identifiers::Error>,
registration: AppserviceRegistration, registration: AppServiceRegistration,
) -> Result<Self> { ) -> Result<Self> {
let appservice = Self::new_with_config( let appservice = Self::new_with_config(
homeserver_url, homeserver_url,
@ -235,7 +238,7 @@ impl Appservice {
pub async fn new_with_config( pub async fn new_with_config(
homeserver_url: impl TryInto<Url, Error = url::ParseError>, homeserver_url: impl TryInto<Url, Error = url::ParseError>,
server_name: impl TryInto<ServerNameBox, Error = identifiers::Error>, server_name: impl TryInto<ServerNameBox, Error = identifiers::Error>,
registration: AppserviceRegistration, registration: AppServiceRegistration,
client_config: ClientConfig, client_config: ClientConfig,
) -> Result<Self> { ) -> Result<Self> {
let homeserver_url = homeserver_url.try_into()?; let homeserver_url = homeserver_url.try_into()?;
@ -244,7 +247,7 @@ impl Appservice {
let clients = Arc::new(DashMap::new()); let clients = Arc::new(DashMap::new());
let sender_localpart = registration.sender_localpart.clone(); let sender_localpart = registration.sender_localpart.clone();
let appservice = Appservice { homeserver_url, server_name, registration, clients }; let appservice = AppService { homeserver_url, server_name, registration, clients };
// we create and cache the [`MainUser`] by default // we create and cache the [`MainUser`] by default
appservice.create_and_cache_client(&sender_localpart, client_config).await?; appservice.create_and_cache_client(&sender_localpart, client_config).await?;
@ -354,12 +357,12 @@ impl Appservice {
/// Convenience wrapper around [`Client::set_event_handler()`] that attaches /// Convenience wrapper around [`Client::set_event_handler()`] that attaches
/// the event handler to the [`MainUser`]'s [`Client`] /// the event handler to the [`MainUser`]'s [`Client`]
/// ///
/// Note that the event handler in the [`Appservice`] context only triggers /// Note that the event handler in the [`AppService`] context only triggers
/// [`join` room `timeline` events], so no state events or events from the /// [`join` room `timeline` events], so no state events or events from the
/// `invite`, `knock` or `leave` scope. The rationale behind that is /// `invite`, `knock` or `leave` scope. The rationale behind that is
/// that incoming Appservice transactions from the homeserver are not /// that incoming AppService transactions from the homeserver are not
/// necessarily bound to a specific user but can cover a multitude of /// necessarily bound to a specific user but can cover a multitude of
/// namespaces, and as such the Appservice basically only "observes /// namespaces, and as such the AppService basically only "observes
/// joined rooms". Also currently homeservers only push PDUs to appservices, /// joined rooms". Also currently homeservers only push PDUs to appservices,
/// no EDUs. There's the open [MSC2409] regarding supporting EDUs in the /// no EDUs. There's the open [MSC2409] regarding supporting EDUs in the
/// future, though it seems to be planned to put EDUs into a different /// future, though it seems to be planned to put EDUs into a different
@ -410,10 +413,10 @@ impl Appservice {
Ok(()) Ok(())
} }
/// Get the Appservice [registration] /// Get the AppService [registration]
/// ///
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration /// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
pub fn registration(&self) -> &AppserviceRegistration { pub fn registration(&self) -> &AppServiceRegistration {
&self.registration &self.registration
} }
@ -424,11 +427,11 @@ impl Appservice {
self.registration.hs_token == hs_token.as_ref() self.registration.hs_token == hs_token.as_ref()
} }
/// Check if given `user_id` is in any of the [`AppserviceRegistration`]'s /// Check if given `user_id` is in any of the [`AppServiceRegistration`]'s
/// `users` namespaces /// `users` namespaces
pub fn user_id_is_in_namespace(&self, user_id: impl AsRef<str>) -> Result<bool> { pub fn user_id_is_in_namespace(&self, user_id: impl AsRef<str>) -> Result<bool> {
for user in &self.registration.namespaces.users { for user in &self.registration.namespaces.users {
// TODO: precompile on Appservice construction // TODO: precompile on AppService construction
let re = Regex::new(&user.regex)?; let re = Regex::new(&user.regex)?;
if re.is_match(user_id.as_ref()) { if re.is_match(user_id.as_ref()) {
return Ok(true); return Ok(true);
@ -438,24 +441,6 @@ impl Appservice {
Ok(false) Ok(false)
} }
/// Returns a closure to be used with [`actix_web::App::configure()`]
///
/// Note that if you handle any of the [application-service-specific
/// routes], including the legacy routes, you will break the appservice
/// functionality.
///
/// [application-service-specific routes]: https://spec.matrix.org/unstable/application-service-api/#legacy-routes
#[cfg(feature = "actix")]
#[cfg_attr(docs, doc(cfg(feature = "actix")))]
pub fn actix_configure(&self) -> impl FnOnce(&mut actix_web::web::ServiceConfig) {
let appservice = self.clone();
move |config| {
config.data(appservice);
webserver::actix::configure(config);
}
}
/// Returns a [`warp::Filter`] to be used as [`warp::serve()`] route /// Returns a [`warp::Filter`] to be used as [`warp::serve()`] route
/// ///
/// Note that if you handle any of the [application-service-specific /// Note that if you handle any of the [application-service-specific
@ -477,13 +462,7 @@ impl Appservice {
pub async fn run(&self, host: impl Into<String>, port: impl Into<u16>) -> Result<()> { pub async fn run(&self, host: impl Into<String>, port: impl Into<u16>) -> Result<()> {
let host = host.into(); let host = host.into();
let port = port.into(); let port = port.into();
info!("Starting Appservice on {}:{}", &host, &port); info!("Starting AppService on {}:{}", &host, &port);
#[cfg(feature = "actix")]
{
webserver::actix::run_server(self.clone(), host, port).await?;
Ok(())
}
#[cfg(feature = "warp")] #[cfg(feature = "warp")]
{ {
@ -491,7 +470,7 @@ impl Appservice {
Ok(()) Ok(())
} }
#[cfg(not(any(feature = "actix", feature = "warp",)))] #[cfg(not(any(feature = "warp",)))]
unreachable!() unreachable!()
} }
} }

View File

@ -1,149 +0,0 @@
// Copyright 2021 Famedly GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::pin::Pin;
pub use actix_web::Scope;
use actix_web::{
dev::Payload,
error::PayloadError,
get, put,
web::{self, BytesMut, Data},
App, FromRequest, HttpRequest, HttpResponse, HttpServer,
};
use futures::Future;
use futures_util::TryStreamExt;
use ruma::api::appservice as api;
use crate::{error::Error, Appservice};
pub async fn run_server(
appservice: Appservice,
host: impl Into<String>,
port: impl Into<u16>,
) -> Result<(), Error> {
HttpServer::new(move || App::new().configure(appservice.actix_configure()))
.bind((host.into(), port.into()))?
.run()
.await?;
Ok(())
}
pub fn configure(config: &mut actix_web::web::ServiceConfig) {
// also handles legacy routes
config.service(push_transactions).service(query_user_id).service(query_room_alias).service(
web::scope("/_matrix/app/v1")
.service(push_transactions)
.service(query_user_id)
.service(query_room_alias),
);
}
#[tracing::instrument]
#[put("/transactions/{txn_id}")]
async fn push_transactions(
request: IncomingRequest<api::event::push_events::v1::IncomingRequest>,
appservice: Data<Appservice>,
) -> Result<HttpResponse, Error> {
if !appservice.compare_hs_token(request.access_token) {
return Ok(HttpResponse::Unauthorized().finish());
}
appservice.get_cached_client(None)?.receive_transaction(request.incoming).await?;
Ok(HttpResponse::Ok().json("{}"))
}
#[tracing::instrument]
#[get("/users/{user_id}")]
async fn query_user_id(
request: IncomingRequest<api::query::query_user_id::v1::IncomingRequest>,
appservice: Data<Appservice>,
) -> Result<HttpResponse, Error> {
if !appservice.compare_hs_token(request.access_token) {
return Ok(HttpResponse::Unauthorized().finish());
}
Ok(HttpResponse::Ok().json("{}"))
}
#[tracing::instrument]
#[get("/rooms/{room_alias}")]
async fn query_room_alias(
request: IncomingRequest<api::query::query_room_alias::v1::IncomingRequest>,
appservice: Data<Appservice>,
) -> Result<HttpResponse, Error> {
if !appservice.compare_hs_token(request.access_token) {
return Ok(HttpResponse::Unauthorized().finish());
}
Ok(HttpResponse::Ok().json("{}"))
}
#[derive(Debug)]
pub struct IncomingRequest<T> {
access_token: String,
incoming: T,
}
impl<T: ruma::api::IncomingRequest> FromRequest for IncomingRequest<T> {
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self, Self::Error>>>>;
type Config = ();
fn from_request(request: &HttpRequest, payload: &mut Payload) -> Self::Future {
let request = request.to_owned();
let payload = payload.take();
Box::pin(async move {
let mut builder =
http::request::Builder::new().method(request.method()).uri(request.uri());
let headers = builder.headers_mut().ok_or(Error::UnknownHttpRequestBuilder)?;
for (key, value) in request.headers().iter() {
headers.append(key, value.to_owned());
}
let bytes = payload
.try_fold(BytesMut::new(), |mut body, chunk| async move {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
.await?
.into();
let access_token = match request.uri().query() {
Some(query) => {
let query: Vec<(String, String)> = ruma::serde::urlencoded::from_str(query)?;
query.into_iter().find(|(key, _)| key == "access_token").map(|(_, value)| value)
}
None => None,
};
let access_token = match access_token {
Some(access_token) => access_token,
None => return Err(Error::MissingAccessToken),
};
let request = builder.body(bytes)?;
let request = crate::transform_legacy_route(request)?;
Ok(IncomingRequest {
access_token,
incoming: ruma::api::IncomingRequest::try_from_http_request(request)?,
})
})
}
}

View File

@ -1,4 +1,2 @@
#[cfg(feature = "actix")]
pub mod actix;
#[cfg(feature = "warp")] #[cfg(feature = "warp")]
pub mod warp; pub mod warp;

View File

@ -15,14 +15,14 @@
use std::{net::ToSocketAddrs, result::Result as StdResult}; use std::{net::ToSocketAddrs, result::Result as StdResult};
use futures::TryFutureExt; use futures::TryFutureExt;
use matrix_sdk::Bytes; use matrix_sdk::{bytes::Bytes, ruma};
use serde::Serialize; use serde::Serialize;
use warp::{filters::BoxedFilter, path::FullPath, Filter, Rejection, Reply}; use warp::{filters::BoxedFilter, path::FullPath, Filter, Rejection, Reply};
use crate::{Appservice, Error, Result}; use crate::{AppService, Error, Result};
pub async fn run_server( pub async fn run_server(
appservice: Appservice, appservice: AppService,
host: impl Into<String>, host: impl Into<String>,
port: impl Into<u16>, port: impl Into<u16>,
) -> Result<()> { ) -> Result<()> {
@ -37,7 +37,7 @@ pub async fn run_server(
} }
} }
pub fn warp_filter(appservice: Appservice) -> BoxedFilter<(impl Reply,)> { pub fn warp_filter(appservice: AppService) -> BoxedFilter<(impl Reply,)> {
// TODO: try to use a struct instead of needlessly cloning appservice multiple // TODO: try to use a struct instead of needlessly cloning appservice multiple
// times on every request // times on every request
warp::any() warp::any()
@ -51,7 +51,7 @@ pub fn warp_filter(appservice: Appservice) -> BoxedFilter<(impl Reply,)> {
mod filters { mod filters {
use super::*; use super::*;
pub fn users(appservice: Appservice) -> BoxedFilter<(impl Reply,)> { pub fn users(appservice: AppService) -> BoxedFilter<(impl Reply,)> {
warp::get() warp::get()
.and( .and(
warp::path!("_matrix" / "app" / "v1" / "users" / String) warp::path!("_matrix" / "app" / "v1" / "users" / String)
@ -65,7 +65,7 @@ mod filters {
.boxed() .boxed()
} }
pub fn rooms(appservice: Appservice) -> BoxedFilter<(impl Reply,)> { pub fn rooms(appservice: AppService) -> BoxedFilter<(impl Reply,)> {
warp::get() warp::get()
.and( .and(
warp::path!("_matrix" / "app" / "v1" / "rooms" / String) warp::path!("_matrix" / "app" / "v1" / "rooms" / String)
@ -79,7 +79,7 @@ mod filters {
.boxed() .boxed()
} }
pub fn transactions(appservice: Appservice) -> BoxedFilter<(impl Reply,)> { pub fn transactions(appservice: AppService) -> BoxedFilter<(impl Reply,)> {
warp::put() warp::put()
.and( .and(
warp::path!("_matrix" / "app" / "v1" / "transactions" / String) warp::path!("_matrix" / "app" / "v1" / "transactions" / String)
@ -93,7 +93,7 @@ mod filters {
.boxed() .boxed()
} }
fn common(appservice: Appservice) -> BoxedFilter<(Appservice, http::Request<Bytes>)> { fn common(appservice: AppService) -> BoxedFilter<(AppService, http::Request<Bytes>)> {
warp::any() warp::any()
.and(filters::valid_access_token(appservice.registration().hs_token.clone())) .and(filters::valid_access_token(appservice.registration().hs_token.clone()))
.map(move || appservice.clone()) .map(move || appservice.clone())
@ -110,7 +110,7 @@ mod filters {
.and(warp::query::raw()) .and(warp::query::raw())
.and_then(|token: String, query: String| async move { .and_then(|token: String, query: String| async move {
let query: Vec<(String, String)> = let query: Vec<(String, String)> =
matrix_sdk::urlencoded::from_str(&query).map_err(Error::from)?; ruma::serde::urlencoded::from_str(&query).map_err(Error::from)?;
if query.into_iter().any(|(key, value)| key == "access_token" && value == token) { if query.into_iter().any(|(key, value)| key == "access_token" && value == token) {
Ok::<(), Rejection>(()) Ok::<(), Rejection>(())
@ -156,7 +156,7 @@ mod handlers {
pub async fn user( pub async fn user(
_user_id: String, _user_id: String,
_appservice: Appservice, _appservice: AppService,
_request: http::Request<Bytes>, _request: http::Request<Bytes>,
) -> StdResult<impl warp::Reply, Rejection> { ) -> StdResult<impl warp::Reply, Rejection> {
Ok(warp::reply::json(&String::from("{}"))) Ok(warp::reply::json(&String::from("{}")))
@ -164,7 +164,7 @@ mod handlers {
pub async fn room( pub async fn room(
_room_id: String, _room_id: String,
_appservice: Appservice, _appservice: AppService,
_request: http::Request<Bytes>, _request: http::Request<Bytes>,
) -> StdResult<impl warp::Reply, Rejection> { ) -> StdResult<impl warp::Reply, Rejection> {
Ok(warp::reply::json(&String::from("{}"))) Ok(warp::reply::json(&String::from("{}")))
@ -172,11 +172,11 @@ mod handlers {
pub async fn transaction( pub async fn transaction(
_txn_id: String, _txn_id: String,
appservice: Appservice, appservice: AppService,
request: http::Request<Bytes>, request: http::Request<Bytes>,
) -> StdResult<impl warp::Reply, Rejection> { ) -> StdResult<impl warp::Reply, Rejection> {
let incoming_transaction: matrix_sdk::api_appservice::event::push_events::v1::IncomingRequest = let incoming_transaction: ruma::api::appservice::event::push_events::v1::IncomingRequest =
matrix_sdk::IncomingRequest::try_from_http_request(request).map_err(Error::from)?; ruma::api::IncomingRequest::try_from_http_request(request).map_err(Error::from)?;
let client = appservice.get_cached_client(None)?; let client = appservice.get_cached_client(None)?;
client.receive_transaction(incoming_transaction).map_err(Error::from).await?; client.receive_transaction(incoming_transaction).map_err(Error::from).await?;

View File

@ -1,12 +1,12 @@
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
#[cfg(feature = "actix")]
use actix_web::{test as actix_test, App as ActixApp, HttpResponse};
use matrix_sdk::{ use matrix_sdk::{
api_appservice::Registration,
async_trait, async_trait,
events::{room::member::MemberEventContent, SyncStateEvent},
room::Room, room::Room,
ruma::{
api::appservice::Registration,
events::{room::member::MemberEventContent, SyncStateEvent},
},
ClientConfig, EventHandler, RequestConfig, ClientConfig, EventHandler, RequestConfig,
}; };
use matrix_sdk_appservice::*; use matrix_sdk_appservice::*;
@ -19,16 +19,16 @@ fn registration_string() -> String {
include_str!("../tests/registration.yaml").to_owned() include_str!("../tests/registration.yaml").to_owned()
} }
async fn appservice(registration: Option<Registration>) -> Result<Appservice> { async fn appservice(registration: Option<Registration>) -> Result<AppService> {
// env::set_var( // env::set_var(
// "RUST_LOG", // "RUST_LOG",
// "mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug,warp=debug", // "mockito=debug,matrix_sdk=debug,ruma=debug,warp=debug",
// ); // );
let _ = tracing_subscriber::fmt::try_init(); let _ = tracing_subscriber::fmt::try_init();
let registration = match registration { let registration = match registration {
Some(registration) => registration.into(), Some(registration) => registration.into(),
None => AppserviceRegistration::try_from_yaml_str(registration_string()).unwrap(), None => AppServiceRegistration::try_from_yaml_str(registration_string()).unwrap(),
}; };
let homeserver_url = mockito::server_url(); let homeserver_url = mockito::server_url();
@ -37,7 +37,7 @@ async fn appservice(registration: Option<Registration>) -> Result<Appservice> {
let client_config = let client_config =
ClientConfig::default().request_config(RequestConfig::default().disable_retry()); ClientConfig::default().request_config(RequestConfig::default().disable_retry());
Ok(Appservice::new_with_config( Ok(AppService::new_with_config(
homeserver_url.as_ref(), homeserver_url.as_ref(),
server_name, server_name,
registration, registration,
@ -97,16 +97,6 @@ async fn test_put_transaction() -> Result<()> {
.into_response() .into_response()
.status(); .status();
#[cfg(feature = "actix")]
let status = {
let app =
actix_test::init_service(ActixApp::new().configure(appservice.actix_configure())).await;
let req = actix_test::TestRequest::put().uri(uri).set_json(&transaction).to_request();
actix_test::call_service(&app, req).await.status()
};
assert_eq!(status, 200); assert_eq!(status, 200);
Ok(()) Ok(())
@ -128,16 +118,6 @@ async fn test_get_user() -> Result<()> {
.into_response() .into_response()
.status(); .status();
#[cfg(feature = "actix")]
let status = {
let app =
actix_test::init_service(ActixApp::new().configure(appservice.actix_configure())).await;
let req = actix_test::TestRequest::get().uri(uri).to_request();
actix_test::call_service(&app, req).await.status()
};
assert_eq!(status, 200); assert_eq!(status, 200);
Ok(()) Ok(())
@ -159,16 +139,6 @@ async fn test_get_room() -> Result<()> {
.into_response() .into_response()
.status(); .status();
#[cfg(feature = "actix")]
let status = {
let app =
actix_test::init_service(ActixApp::new().configure(appservice.actix_configure())).await;
let req = actix_test::TestRequest::get().uri(uri).to_request();
actix_test::call_service(&app, req).await.status()
};
assert_eq!(status, 200); assert_eq!(status, 200);
Ok(()) Ok(())
@ -195,16 +165,6 @@ async fn test_invalid_access_token() -> Result<()> {
.into_response() .into_response()
.status(); .status();
#[cfg(feature = "actix")]
let status = {
let app =
actix_test::init_service(ActixApp::new().configure(appservice.actix_configure())).await;
let req = actix_test::TestRequest::put().uri(uri).set_json(&transaction).to_request();
actix_test::call_service(&app, req).await.status()
};
assert_eq!(status, 401); assert_eq!(status, 401);
Ok(()) Ok(())
@ -235,20 +195,6 @@ async fn test_no_access_token() -> Result<()> {
assert_eq!(status, 401); assert_eq!(status, 401);
} }
#[cfg(feature = "actix")]
{
let app =
actix_test::init_service(ActixApp::new().configure(appservice.actix_configure())).await;
let req = actix_test::TestRequest::put().uri(uri).set_json(&transaction).to_request();
let resp = actix_test::call_service(&app, req).await;
// TODO: this should actually return a 401 but is 500 because something in the
// extractor fails
assert_eq!(resp.status(), 500);
}
Ok(()) Ok(())
} }
@ -294,16 +240,6 @@ async fn test_event_handler() -> Result<()> {
.await .await
.unwrap(); .unwrap();
#[cfg(feature = "actix")]
{
let app =
actix_test::init_service(ActixApp::new().configure(appservice.actix_configure())).await;
let req = actix_test::TestRequest::put().uri(uri).set_json(&transaction).to_request();
actix_test::call_service(&app, req).await;
};
let on_room_member_called = *example.on_state_member.lock().unwrap(); let on_room_member_called = *example.on_state_member.lock().unwrap();
assert!(on_room_member_called); assert!(on_room_member_called);
@ -330,20 +266,6 @@ async fn test_unrelated_path() -> Result<()> {
response.status() response.status()
}; };
#[cfg(feature = "actix")]
let status = {
let app = actix_test::init_service(
ActixApp::new()
.configure(appservice.actix_configure())
.route("/unrelated", actix_web::web::get().to(HttpResponse::Ok)),
)
.await;
let req = actix_test::TestRequest::get().uri("/unrelated").to_request();
actix_test::call_service(&app, req).await.status()
};
assert_eq!(status, 200); assert_eq!(status, 200);
Ok(()) Ok(())
@ -355,7 +277,7 @@ mod registration {
#[test] #[test]
fn test_registration() -> Result<()> { fn test_registration() -> Result<()> {
let registration: Registration = serde_yaml::from_str(&registration_string())?; let registration: Registration = serde_yaml::from_str(&registration_string())?;
let registration: AppserviceRegistration = registration.into(); let registration: AppServiceRegistration = registration.into();
assert_eq!(registration.id, "appservice"); assert_eq!(registration.id, "appservice");
@ -364,7 +286,7 @@ mod registration {
#[test] #[test]
fn test_registration_from_yaml_file() -> Result<()> { fn test_registration_from_yaml_file() -> Result<()> {
let registration = AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml")?; let registration = AppServiceRegistration::try_from_yaml_file("./tests/registration.yaml")?;
assert_eq!(registration.id, "appservice"); assert_eq!(registration.id, "appservice");
@ -373,7 +295,7 @@ mod registration {
#[test] #[test]
fn test_registration_from_yaml_str() -> Result<()> { fn test_registration_from_yaml_str() -> Result<()> {
let registration = AppserviceRegistration::try_from_yaml_str(registration_string())?; let registration = AppServiceRegistration::try_from_yaml_str(registration_string())?;
assert_eq!(registration.id, "appservice"); assert_eq!(registration.id, "appservice");

View File

@ -8,7 +8,7 @@ license = "Apache-2.0"
name = "matrix-sdk-base" name = "matrix-sdk-base"
readme = "README.md" readme = "README.md"
repository = "https://github.com/matrix-org/matrix-rust-sdk" repository = "https://github.com/matrix-org/matrix-rust-sdk"
version = "0.2.0" version = "0.3.0"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["docs"] features = ["docs"]
@ -26,44 +26,44 @@ docs = ["encryption", "sled_cryptostore"]
[dependencies] [dependencies]
dashmap = "4.0.2" dashmap = "4.0.2"
lru = "0.6.5" lru = "0.6.5"
ruma = { version = "0.1.2", features = ["client-api-c", "unstable-pre-spec"] } ruma = { version = "0.2.0", features = ["client-api-c", "unstable-pre-spec"] }
serde = { version = "1.0.122", features = ["rc"] } serde = { version = "1.0.126", features = ["rc"] }
serde_json = "1.0.61" serde_json = "1.0.64"
tracing = "0.1.22" tracing = "0.1.26"
matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" } matrix-sdk-common = { version = "0.3.0", path = "../matrix_sdk_common" }
matrix-sdk-crypto = { version = "0.2.0", path = "../matrix_sdk_crypto", optional = true } matrix-sdk-crypto = { version = "0.3.0", path = "../matrix_sdk_crypto", optional = true }
# Misc dependencies # Misc dependencies
thiserror = "1.0.23" thiserror = "1.0.25"
futures = "0.3.12" futures = "0.3.15"
zeroize = { version = "1.2.0", features = ["zeroize_derive"] } zeroize = { version = "1.3.0", features = ["zeroize_derive"] }
# Deps for the sled state store # Deps for the sled state store
sled = { version = "0.34.6", optional = true } sled = { version = "0.34.6", optional = true }
chacha20poly1305 = { version = "0.7.1", optional = true } chacha20poly1305 = { version = "0.8.0", optional = true }
pbkdf2 = { version = "0.6.0", default-features = false, optional = true } pbkdf2 = { version = "0.8.0", default-features = false, optional = true }
hmac = { version = "0.10.1", optional = true } hmac = { version = "0.11.0", optional = true }
sha2 = { version = "0.9.2", optional = true } sha2 = { version = "0.9.5", optional = true }
rand = { version = "0.8.2", optional = true } rand = { version = "0.8.4", optional = true }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio] [target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio]
version = "1.1.0" version = "1.7.1"
default-features = false default-features = false
features = ["sync", "fs"] features = ["sync", "fs"]
[dev-dependencies] [dev-dependencies]
matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" } matrix-sdk-test = { version = "0.3.0", path = "../matrix_sdk_test" }
http = "0.2.3" http = "0.2.4"
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] } tokio = { version = "1.7.1", default-features = false, features = ["rt-multi-thread", "macros"] }
tempfile = "3.2.0" tempfile = "3.2.0"
rustyline = "7.1.0" rustyline = "8.2.0"
rustyline-derive = "0.4.0" rustyline-derive = "0.4.0"
atty = "0.2.14" atty = "0.2.14"
clap = "2.33.3" clap = "2.33.3"
syntect = "4.5.0" syntect = "4.5.0"
[target.'cfg(target_arch = "wasm32")'.dev-dependencies] [target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3.19" wasm-bindgen-test = "0.3.24"

View File

@ -6,10 +6,7 @@ use atty::Stream;
use clap::{App as Argparse, AppSettings as ArgParseSettings, Arg, ArgMatches, SubCommand}; use clap::{App as Argparse, AppSettings as ArgParseSettings, Arg, ArgMatches, SubCommand};
use futures::executor::block_on; use futures::executor::block_on;
use matrix_sdk_base::{RoomInfo, Store}; use matrix_sdk_base::{RoomInfo, Store};
use ruma::{ use ruma::{events::EventType, RoomId, UserId};
events::EventType,
identifiers::{RoomId, UserId},
};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use rustyline::{ use rustyline::{
completion::{Completer, Pair}, completion::{Completer, Pair},

View File

@ -36,7 +36,7 @@ use matrix_sdk_common::{locks::Mutex, uuid::Uuid};
use matrix_sdk_crypto::{ use matrix_sdk_crypto::{
store::{CryptoStore, CryptoStoreError}, store::{CryptoStore, CryptoStoreError},
Device, EncryptionSettings, IncomingResponse, MegolmError, OlmError, OlmMachine, Device, EncryptionSettings, IncomingResponse, MegolmError, OlmError, OlmMachine,
OutgoingRequest, Sas, ToDeviceRequest, UserDevices, OutgoingRequest, ToDeviceRequest, UserDevices,
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use ruma::{ use ruma::{
@ -89,8 +89,8 @@ pub struct AdditionalUnsignedData {
pub prev_content: Option<Raw<MemberEventContent>>, pub prev_content: Option<Raw<MemberEventContent>>,
} }
/// Transform state event by hoisting `prev_content` field from `unsigned` to /// Transform an `AnySyncStateEvent` by hoisting `prev_content` field from
/// the top level. /// `unsigned` to the top level.
/// ///
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in /// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
/// `unsigned` contrary to the C2S spec. Some more discussion can be found /// `unsigned` contrary to the C2S spec. Some more discussion can be found
@ -129,7 +129,17 @@ fn hoist_member_event(
Ok(e) Ok(e)
} }
fn hoist_room_event_prev_content( /// Transform an `AnySyncRoomEvent` by hoisting `prev_content` field from
/// `unsigned` to the top level.
///
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
/// `unsigned` contrary to the C2S spec. Some more discussion can be found
/// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we
/// use this to hoist up `prev_content` to the top level.
///
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
pub fn hoist_room_event_prev_content(
event: &Raw<AnySyncRoomEvent>, event: &Raw<AnySyncRoomEvent>,
) -> StdResult<AnySyncRoomEvent, serde_json::Error> { ) -> StdResult<AnySyncRoomEvent, serde_json::Error> {
let prev_content = event let prev_content = event
@ -1202,21 +1212,6 @@ impl BaseClient {
} }
} }
/// Get a `Sas` verification object with the given flow id.
///
/// # Arguments
///
/// * `flow_id` - The unique id that identifies a interactive verification
/// flow. For in-room verifications this will be the event id of the
/// *m.key.verification.request* event that started the flow, for the
/// to-device verification flows this will be the transaction id of the
/// *m.key.verification.start* event.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
}
/// Get a specific device of a user. /// Get a specific device of a user.
/// ///
/// # Arguments /// # Arguments

View File

@ -50,7 +50,9 @@ mod rooms;
mod session; mod session;
mod store; mod store;
pub use client::{BaseClient, BaseClientConfig}; pub use client::{
hoist_and_deserialize_state_event, hoist_room_event_prev_content, BaseClient, BaseClientConfig,
};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_crypto as crypto; pub use matrix_sdk_crypto as crypto;

View File

@ -28,9 +28,9 @@ use ruma::{
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
AnySyncStateEvent, EventType, AnySyncStateEvent, EventType,
}, },
identifiers::{EventId, MxcUri, RoomId, UserId},
receipt::ReceiptType, receipt::ReceiptType,
serde::Raw, serde::Raw,
EventId, MxcUri, RoomId, UserId,
}; };
use tracing::info; use tracing::info;
@ -69,7 +69,7 @@ pub struct MemoryStore {
} }
impl MemoryStore { impl MemoryStore {
#[cfg(not(feature = "sled_state_store"))] #[allow(dead_code)]
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
sync_token: Arc::new(RwLock::new(None)), sync_token: Arc::new(RwLock::new(None)),
@ -581,15 +581,12 @@ impl StateStore for MemoryStore {
} }
#[cfg(test)] #[cfg(test)]
#[cfg(not(feature = "sled_state_store"))]
mod test { mod test {
use matrix_sdk_common::{
api::client::r0::media::get_content_thumbnail::Method,
identifiers::{event_id, mxc_uri, room_id, user_id, UserId},
receipt::ReceiptType,
uint,
};
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use ruma::{
api::client::r0::media::get_content_thumbnail::Method, event_id, mxc_uri,
receipt::ReceiptType, room_id, uint, user_id, UserId,
};
use serde_json::json; use serde_json::json;
use super::{MemoryStore, StateChanges}; use super::{MemoryStore, StateChanges};

View File

@ -732,6 +732,7 @@ impl SledStore {
.map(|u| { .map(|u| {
u.map_err(StoreError::Sled).and_then(|(key, value)| { u.map_err(StoreError::Sled).and_then(|(key, value)| {
self.deserialize_event(&value) self.deserialize_event(&value)
// TODO remove this unwrapping
.map(|receipt| { .map(|receipt| {
(decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt) (decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt)
}) })
@ -922,6 +923,7 @@ mod test {
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use ruma::{ use ruma::{
api::client::r0::media::get_content_thumbnail::Method, api::client::r0::media::get_content_thumbnail::Method,
event_id,
events::{ events::{
room::{ room::{
member::{MemberEventContent, MembershipState}, member::{MemberEventContent, MembershipState},
@ -929,10 +931,11 @@ mod test {
}, },
AnySyncStateEvent, EventType, Unsigned, AnySyncStateEvent, EventType, Unsigned,
}, },
identifiers::{event_id, mxc_uri, room_id, user_id, EventId, UserId}, mxc_uri,
receipt::ReceiptType, receipt::ReceiptType,
room_id,
serde::Raw, serde::Raw,
uint, MilliSecondsSinceUnixEpoch, uint, user_id, EventId, MilliSecondsSinceUnixEpoch, UserId,
}; };
use serde_json::json; use serde_json::json;

View File

@ -8,24 +8,24 @@ license = "Apache-2.0"
name = "matrix-sdk-common" name = "matrix-sdk-common"
readme = "README.md" readme = "README.md"
repository = "https://github.com/matrix-org/matrix-rust-sdk" repository = "https://github.com/matrix-org/matrix-rust-sdk"
version = "0.2.0" version = "0.3.0"
[dependencies] [dependencies]
async-trait = "0.1.42" async-trait = "0.1.50"
instant = { version = "0.1.9", features = ["wasm-bindgen", "now"] } instant = { version = "0.1.9", features = ["wasm-bindgen", "now"] }
ruma = { version = "0.1.2", features = ["client-api-c"] } ruma = { version = "0.2.0", features = ["client-api-c"] }
serde = "1.0.122" serde = "1.0.126"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
uuid = { version = "0.8.2", default-features = false, features = ["v4", "serde"] } uuid = { version = "0.8.2", default-features = false, features = ["v4", "serde"] }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio] [target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio]
version = "1.1.0" version = "1.7.1"
default-features = false default-features = false
features = ["rt", "sync"] features = ["rt", "sync"]
[target.'cfg(target_arch = "wasm32")'.dependencies] [target.'cfg(target_arch = "wasm32")'.dependencies]
futures = "0.3.12" futures = "0.3.15"
futures-locks = { version = "0.6.0", default-features = false } futures-locks = { version = "0.6.0", default-features = false }
wasm-bindgen-futures = "0.4" wasm-bindgen-futures = "0.4.24"
uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] } uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] }

View File

@ -12,9 +12,8 @@ use ruma::{
room::member::MemberEventContent, AnySyncRoomEvent, StateEvent, StrippedStateEvent, room::member::MemberEventContent, AnySyncRoomEvent, StateEvent, StrippedStateEvent,
SyncStateEvent, Unsigned, SyncStateEvent, Unsigned,
}, },
identifiers::{DeviceKeyAlgorithm, EventId, RoomId, UserId},
serde::Raw, serde::Raw,
DeviceIdBox, MilliSecondsSinceUnixEpoch, DeviceIdBox, DeviceKeyAlgorithm, EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View File

@ -8,7 +8,7 @@ license = "Apache-2.0"
name = "matrix-sdk-crypto" name = "matrix-sdk-crypto"
readme = "README.md" readme = "README.md"
repository = "https://github.com/matrix-org/matrix-rust-sdk" repository = "https://github.com/matrix-org/matrix-rust-sdk"
version = "0.2.0" version = "0.3.0"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["docs"] features = ["docs"]
@ -20,42 +20,44 @@ sled_cryptostore = ["sled"]
docs = ["sled_cryptostore"] docs = ["sled_cryptostore"]
[dependencies] [dependencies]
matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" } matrix-qrcode = { version = "0.1.0", path = "../matrix_qrcode" }
ruma = { version = "0.1.2", features = ["client-api-c", "unstable-pre-spec"] } matrix-sdk-common = { version = "0.3.0", path = "../matrix_sdk_common" }
ruma = { version = "0.2.0", features = ["client-api-c", "unstable-pre-spec"] }
olm-rs = { version = "1.0.0", features = ["serde"] } olm-rs = { version = "1.0.1", features = ["serde"] }
getrandom = "0.2.2" getrandom = "0.2.3"
serde = { version = "1.0.122", features = ["derive", "rc"] } serde = { version = "1.0.126", features = ["derive", "rc"] }
serde_json = "1.0.61" serde_json = "1.0.64"
zeroize = { version = "1.2.0", features = ["zeroize_derive"] } zeroize = { version = "1.3.0", features = ["zeroize_derive"] }
# Misc dependencies # Misc dependencies
futures = "0.3.12" futures = "0.3.15"
sled = { version = "0.34.6", optional = true } sled = { version = "0.34.6", optional = true }
thiserror = "1.0.23" thiserror = "1.0.25"
tracing = "0.1.22" tracing = "0.1.26"
atomic = "0.5.0" atomic = "0.5.0"
dashmap = "4.0.2" dashmap = "4.0.2"
sha2 = "0.9.2" sha2 = "0.9.5"
aes-gcm = "0.8.0" aes-gcm = "0.9.2"
aes-ctr = "0.6.0" aes = { version = "0.7.4", features = ["ctr"] }
pbkdf2 = { version = "0.6.0", default-features = false } pbkdf2 = { version = "0.8.0", default-features = false }
hmac = "0.10.1" hmac = "0.11.0"
base64 = "0.13.0" base64 = "0.13.0"
byteorder = "1.4.2" byteorder = "1.4.3"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] } tokio = { version = "1.7.1", default-features = false, features = ["rt-multi-thread", "macros"] }
proptest = "0.10.1" proptest = "1.0.0"
serde_json = "1.0.61" matches = "0.1.8"
serde_json = "1.0.64"
tempfile = "3.2.0" tempfile = "3.2.0"
http = "0.2.3" http = "0.2.4"
matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" } matrix-sdk-test = { version = "0.3.0", path = "../matrix_sdk_test" }
indoc = "1.0.3" indoc = "1.0.3"
criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] } criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] }
[target.'cfg(target_os = "linux")'.dev-dependencies] [target.'cfg(target_os = "linux")'.dev-dependencies]
pprof = { version = "0.4.2", features = ["flamegraph"] } pprof = { version = "0.4.3", features = ["flamegraph"] }
[[bench]] [[bench]]
name = "crypto_bench" name = "crypto_bench"

View File

@ -17,9 +17,9 @@ use std::{
io::{Error as IoError, ErrorKind, Read}, io::{Error as IoError, ErrorKind, Read},
}; };
use aes_ctr::{ use aes::{
cipher::{NewStreamCipher, SyncStreamCipher}, cipher::{generic_array::GenericArray, FromBlockCipher, NewBlockCipher, StreamCipher},
Aes256Ctr, Aes256, Aes256Ctr,
}; };
use base64::DecodeError; use base64::DecodeError;
use getrandom::getrandom; use getrandom::getrandom;
@ -37,17 +37,25 @@ const VERSION: &str = "v2";
/// A wrapper that transparently encrypts anything that implements `Read` as an /// A wrapper that transparently encrypts anything that implements `Read` as an
/// Matrix attachment. /// Matrix attachment.
#[derive(Debug)]
pub struct AttachmentDecryptor<'a, R: 'a + Read> { pub struct AttachmentDecryptor<'a, R: 'a + Read> {
inner_reader: &'a mut R, inner: &'a mut R,
expected_hash: Vec<u8>, expected_hash: Vec<u8>,
sha: Sha256, sha: Sha256,
aes: Aes256Ctr, aes: Aes256Ctr,
} }
impl<'a, R: 'a + Read + std::fmt::Debug> std::fmt::Debug for AttachmentDecryptor<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AttachmentDecryptor")
.field("inner", &self.inner)
.field("expected_hash", &self.expected_hash)
.finish()
}
}
impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> { impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let read_bytes = self.inner_reader.read(buf)?; let read_bytes = self.inner.read(buf)?;
if read_bytes == 0 { if read_bytes == 0 {
let hash = self.sha.finalize_reset(); let hash = self.sha.finalize_reset();
@ -126,19 +134,20 @@ impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
let hash = decode(info.hashes.get("sha256").ok_or(DecryptorError::MissingHash)?)?; let hash = decode(info.hashes.get("sha256").ok_or(DecryptorError::MissingHash)?)?;
let key = Zeroizing::from(decode_url_safe(info.web_key.k)?); let key = Zeroizing::from(decode_url_safe(info.web_key.k)?);
let iv = decode(info.iv)?; let iv = decode(info.iv)?;
let iv = GenericArray::from_exact_iter(iv).ok_or(DecryptorError::KeyNonceLength)?;
let sha = Sha256::default(); let sha = Sha256::default();
let aes = Aes256Ctr::new_var(&key, &iv).map_err(|_| DecryptorError::KeyNonceLength)?; let aes = Aes256::new_from_slice(&key).map_err(|_| DecryptorError::KeyNonceLength)?;
let aes = Aes256Ctr::from_block_cipher(aes, &iv);
Ok(AttachmentDecryptor { inner_reader: input, expected_hash: hash, sha, aes }) Ok(AttachmentDecryptor { inner: input, expected_hash: hash, sha, aes })
} }
} }
/// A wrapper that transparently encrypts anything that implements `Read`. /// A wrapper that transparently encrypts anything that implements `Read`.
#[derive(Debug)]
pub struct AttachmentEncryptor<'a, R: Read + 'a> { pub struct AttachmentEncryptor<'a, R: Read + 'a> {
finished: bool, finished: bool,
inner_reader: &'a mut R, inner: &'a mut R,
web_key: JsonWebKey, web_key: JsonWebKey,
iv: String, iv: String,
hashes: BTreeMap<String, String>, hashes: BTreeMap<String, String>,
@ -146,9 +155,18 @@ pub struct AttachmentEncryptor<'a, R: Read + 'a> {
sha: Sha256, sha: Sha256,
} }
impl<'a, R: 'a + Read + std::fmt::Debug> std::fmt::Debug for AttachmentEncryptor<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AttachmentEncryptor")
.field("inner", &self.inner)
.field("finished", &self.finished)
.finish()
}
}
impl<'a, R: Read + 'a> Read for AttachmentEncryptor<'a, R> { impl<'a, R: Read + 'a> Read for AttachmentEncryptor<'a, R> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let read_bytes = self.inner_reader.read(buf)?; let read_bytes = self.inner.read(buf)?;
if read_bytes == 0 { if read_bytes == 0 {
let hash = self.sha.finalize_reset(); let hash = self.sha.finalize_reset();
@ -209,12 +227,15 @@ impl<'a, R: Read + 'a> AttachmentEncryptor<'a, R> {
ext: true, ext: true,
}); });
let encoded_iv = encode(&*iv); let encoded_iv = encode(&*iv);
let iv = GenericArray::from_slice(&*iv);
let key = GenericArray::from_slice(&*key);
let aes = Aes256Ctr::new_var(&*key, &*iv).expect("Cannot create AES encryption object."); let aes = Aes256::new(key);
let aes = Aes256Ctr::from_block_cipher(aes, iv);
AttachmentEncryptor { AttachmentEncryptor {
finished: false, finished: false,
inner_reader: reader, inner: reader,
iv: encoded_iv, iv: encoded_iv,
web_key, web_key,
hashes: BTreeMap::new(), hashes: BTreeMap::new(),

View File

@ -14,9 +14,9 @@
use std::io::{Cursor, Read, Seek, SeekFrom}; use std::io::{Cursor, Read, Seek, SeekFrom};
use aes_ctr::{ use aes::{
cipher::{NewStreamCipher, SyncStreamCipher}, cipher::{generic_array::GenericArray, FromBlockCipher, NewBlockCipher, StreamCipher},
Aes256Ctr, Aes256, Aes256Ctr,
}; };
use byteorder::{BigEndian, ReadBytesExt}; use byteorder::{BigEndian, ReadBytesExt};
use getrandom::getrandom; use getrandom::getrandom;
@ -161,7 +161,12 @@ fn encrypt_helper(mut plaintext: &mut [u8], passphrase: &str, rounds: u32) -> St
pbkdf2::<Hmac<Sha512>>(passphrase.as_bytes(), &salt, rounds, &mut derived_keys); pbkdf2::<Hmac<Sha512>>(passphrase.as_bytes(), &salt, rounds, &mut derived_keys);
let (key, hmac_key) = derived_keys.split_at(KEY_SIZE); let (key, hmac_key) = derived_keys.split_at(KEY_SIZE);
let mut aes = Aes256Ctr::new_var(key, &iv.to_be_bytes()).expect("Can't create AES object"); let key = GenericArray::from_slice(key);
let iv = iv.to_be_bytes();
let iv = GenericArray::from_slice(&iv);
let aes = Aes256::new(key);
let mut aes = Aes256Ctr::from_block_cipher(aes, iv);
aes.apply_keystream(&mut plaintext); aes.apply_keystream(&mut plaintext);
@ -169,11 +174,11 @@ fn encrypt_helper(mut plaintext: &mut [u8], passphrase: &str, rounds: u32) -> St
payload.extend(&VERSION.to_be_bytes()); payload.extend(&VERSION.to_be_bytes());
payload.extend(&salt); payload.extend(&salt);
payload.extend(&iv.to_be_bytes()); payload.extend(&*iv);
payload.extend(&rounds.to_be_bytes()); payload.extend(&rounds.to_be_bytes());
payload.extend_from_slice(plaintext); payload.extend_from_slice(plaintext);
let mut hmac = Hmac::<Sha256>::new_varkey(hmac_key).expect("Can't create HMAC object"); let mut hmac = Hmac::<Sha256>::new_from_slice(hmac_key).expect("Can't create HMAC object");
hmac.update(&payload); hmac.update(&payload);
let mac = hmac.finalize(); let mac = hmac.finalize();
@ -213,12 +218,16 @@ fn decrypt_helper(ciphertext: &str, passphrase: &str) -> Result<String, KeyExpor
pbkdf2::<Hmac<Sha512>>(passphrase.as_bytes(), &salt, rounds, &mut derived_keys); pbkdf2::<Hmac<Sha512>>(passphrase.as_bytes(), &salt, rounds, &mut derived_keys);
let (key, hmac_key) = derived_keys.split_at(KEY_SIZE); let (key, hmac_key) = derived_keys.split_at(KEY_SIZE);
let mut hmac = Hmac::<Sha256>::new_varkey(hmac_key).expect("Can't create an HMAC object"); let mut hmac = Hmac::<Sha256>::new_from_slice(hmac_key).expect("Can't create an HMAC object");
hmac.update(&decoded[0..ciphertext_end]); hmac.update(&decoded[0..ciphertext_end]);
hmac.verify(&mac).map_err(|_| KeyExportError::InvalidMac)?; hmac.verify(&mac).map_err(|_| KeyExportError::InvalidMac)?;
let key = GenericArray::from_slice(key);
let iv = GenericArray::from_slice(&iv);
let mut ciphertext = &mut decoded[ciphertext_start..ciphertext_end]; let mut ciphertext = &mut decoded[ciphertext_start..ciphertext_end];
let mut aes = Aes256Ctr::new_var(key, &iv).expect("Can't create an AES object"); let aes = Aes256::new(key);
let mut aes = Aes256Ctr::from_block_cipher(aes, iv);
aes.apply_keystream(&mut ciphertext); aes.apply_keystream(&mut ciphertext);
Ok(String::from_utf8(ciphertext.to_owned())?) Ok(String::from_utf8(ciphertext.to_owned())?)

View File

@ -25,15 +25,13 @@ use std::{
use atomic::Atomic; use atomic::Atomic;
use matrix_sdk_common::locks::Mutex; use matrix_sdk_common::locks::Mutex;
use ruma::{ use ruma::{
api::client::r0::keys::SignedKey, encryption::{DeviceKeys, SignedKey},
encryption::DeviceKeys,
events::{ events::{
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent, forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
room::encrypted::EncryptedEventContent, EventType, key::verification::VerificationMethod, room::encrypted::EncryptedEventContent,
}, AnyToDeviceEventContent,
identifiers::{
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId,
}, },
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId,
}; };
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value}; use serde_json::{json, Value};
@ -42,11 +40,11 @@ use tracing::warn;
use super::{atomic_bool_deserializer, atomic_bool_serializer}; use super::{atomic_bool_deserializer, atomic_bool_serializer};
use crate::{ use crate::{
error::{EventError, OlmError, OlmResult, SignatureError}, error::{EventError, OlmError, OlmResult, SignatureError},
identities::{OwnUserIdentity, UserIdentities}, identities::{ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities},
olm::{InboundGroupSession, PrivateCrossSigningIdentity, Session, Utility}, olm::{InboundGroupSession, PrivateCrossSigningIdentity, Session, Utility},
store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult}, store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult},
verification::VerificationMachine, verification::VerificationMachine,
OutgoingVerificationRequest, Sas, ToDeviceRequest, OutgoingVerificationRequest, Sas, ToDeviceRequest, VerificationRequest,
}; };
#[cfg(test)] #[cfg(test)]
use crate::{OlmMachine, ReadOnlyAccount}; use crate::{OlmMachine, ReadOnlyAccount};
@ -107,8 +105,8 @@ pub struct Device {
pub(crate) inner: ReadOnlyDevice, pub(crate) inner: ReadOnlyDevice,
pub(crate) private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>, pub(crate) private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
pub(crate) verification_machine: VerificationMachine, pub(crate) verification_machine: VerificationMachine,
pub(crate) own_identity: Option<OwnUserIdentity>, pub(crate) own_identity: Option<ReadOnlyOwnUserIdentity>,
pub(crate) device_owner_identity: Option<UserIdentities>, pub(crate) device_owner_identity: Option<ReadOnlyUserIdentities>,
} }
impl std::fmt::Debug for Device { impl std::fmt::Debug for Device {
@ -128,7 +126,13 @@ impl Deref for Device {
impl Device { impl Device {
/// Start a interactive verification with this `Device` /// Start a interactive verification with this `Device`
/// ///
/// Returns a `Sas` object and to-device request that needs to be sent out. /// Returns a `Sas` object and a to-device request that needs to be sent
/// out.
///
/// This method has been deprecated in the spec and the
/// [`request_verification()`] method should be used instead.
///
/// [`request_verification()`]: #method.request_verification
pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> { pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> {
let (sas, request) = self.verification_machine.start_sas(self.inner.clone()).await?; let (sas, request) = self.verification_machine.start_sas(self.inner.clone()).await?;
@ -139,6 +143,42 @@ impl Device {
} }
} }
/// Request an interacitve verification with this `Device`
///
/// Returns a `VerificationRequest` object and a to-device request that
/// needs to be sent out.
pub async fn request_verification(&self) -> (VerificationRequest, OutgoingVerificationRequest) {
self.request_verification_helper(None).await
}
/// Request an interacitve verification with this `Device`
///
/// Returns a `VerificationRequest` object and a to-device request that
/// needs to be sent out.
///
/// # Arguments
///
/// * `methods` - The verification methods that we want to support.
pub async fn request_verification_with_methods(
&self,
methods: Vec<VerificationMethod>,
) -> (VerificationRequest, OutgoingVerificationRequest) {
self.request_verification_helper(Some(methods)).await
}
async fn request_verification_helper(
&self,
methods: Option<Vec<VerificationMethod>>,
) -> (VerificationRequest, OutgoingVerificationRequest) {
self.verification_machine
.request_to_device_verification(
self.user_id(),
vec![self.device_id().to_owned()],
methods,
)
.await
}
/// Get the Olm sessions that belong to this device. /// Get the Olm sessions that belong to this device.
pub(crate) async fn get_sessions(&self) -> StoreResult<Option<Arc<Mutex<Vec<Session>>>>> { pub(crate) async fn get_sessions(&self) -> StoreResult<Option<Arc<Mutex<Vec<Session>>>>> {
if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) { if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) {
@ -148,9 +188,20 @@ impl Device {
} }
} }
/// Get the trust state of the device. /// Is this device considered to be verified.
pub fn trust_state(&self) -> bool { ///
self.inner.trust_state(&self.own_identity, &self.device_owner_identity) /// This method returns true if either [`is_locally_trusted()`] returns true
/// or if [`is_cross_signing_trusted()`] returns true.
///
/// [`is_locally_trusted()`]: #method.is_locally_trusted
/// [`is_cross_signing_trusted()`]: #method.is_cross_signing_trusted
pub fn verified(&self) -> bool {
self.inner.verified(&self.own_identity, &self.device_owner_identity)
}
/// Is this device considered to be verified using cross signing.
pub fn is_cross_signing_trusted(&self) -> bool {
self.inner.is_cross_signing_trusted(&self.own_identity, &self.device_owner_identity)
} }
/// Set the local trust state of the device to the given state. /// Set the local trust state of the device to the given state.
@ -176,15 +227,12 @@ impl Device {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `event_type` - The type of the event.
///
/// * `content` - The content of the event that should be encrypted. /// * `content` - The content of the event that should be encrypted.
pub(crate) async fn encrypt( pub(crate) async fn encrypt(
&self, &self,
event_type: EventType, content: AnyToDeviceEventContent,
content: Value,
) -> OlmResult<(Session, EncryptedEventContent)> { ) -> OlmResult<(Session, EncryptedEventContent)> {
self.inner.encrypt(&*self.verification_machine.store, event_type, content).await self.inner.encrypt(&*self.verification_machine.store, content).await
} }
/// Encrypt the given inbound group session as a forwarded room key for this /// Encrypt the given inbound group session as a forwarded room key for this
@ -213,8 +261,7 @@ impl Device {
); );
}; };
let content = serde_json::to_value(content)?; self.encrypt(AnyToDeviceEventContent::ForwardedRoomKey(content)).await
self.encrypt(EventType::ForwardedRoomKey, content).await
} }
} }
@ -224,8 +271,8 @@ pub struct UserDevices {
pub(crate) inner: HashMap<DeviceIdBox, ReadOnlyDevice>, pub(crate) inner: HashMap<DeviceIdBox, ReadOnlyDevice>,
pub(crate) private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>, pub(crate) private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
pub(crate) verification_machine: VerificationMachine, pub(crate) verification_machine: VerificationMachine,
pub(crate) own_identity: Option<OwnUserIdentity>, pub(crate) own_identity: Option<ReadOnlyOwnUserIdentity>,
pub(crate) device_owner_identity: Option<UserIdentities>, pub(crate) device_owner_identity: Option<ReadOnlyUserIdentities>,
} }
impl UserDevices { impl UserDevices {
@ -243,7 +290,7 @@ impl UserDevices {
/// Returns true if there is at least one devices of this user that is /// Returns true if there is at least one devices of this user that is
/// considered to be verified, false otherwise. /// considered to be verified, false otherwise.
pub fn is_any_verified(&self) -> bool { pub fn is_any_verified(&self) -> bool {
self.inner.values().any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity)) self.inner.values().any(|d| d.verified(&self.own_identity, &self.device_owner_identity))
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
@ -347,7 +394,7 @@ impl ReadOnlyDevice {
} }
/// Is the device locally marked as trusted. /// Is the device locally marked as trusted.
pub fn is_trusted(&self) -> bool { pub fn is_locally_trusted(&self) -> bool {
self.local_trust_state() == LocalTrust::Verified self.local_trust_state() == LocalTrust::Verified
} }
@ -376,53 +423,47 @@ impl ReadOnlyDevice {
self.deleted.load(Ordering::Relaxed) self.deleted.load(Ordering::Relaxed)
} }
pub(crate) fn trust_state( pub(crate) fn verified(
&self, &self,
own_identity: &Option<OwnUserIdentity>, own_identity: &Option<ReadOnlyOwnUserIdentity>,
device_owner: &Option<UserIdentities>, device_owner: &Option<ReadOnlyUserIdentities>,
) -> bool { ) -> bool {
// TODO we want to return an enum mentioning if the trust is local, if self.is_locally_trusted() || self.is_cross_signing_trusted(own_identity, device_owner)
// only the identity is trusted, if the identity and the device are }
// trusted.
if self.is_trusted() {
// If the device is locally marked as verified just return so, no
// need to check signatures.
true
} else {
own_identity.as_ref().map_or(false, |own_identity| {
// Our own identity needs to be marked as verified.
own_identity.is_verified()
&& device_owner
.as_ref()
.map(|device_identity| match device_identity {
// If it's one of our own devices, just check that
// we signed the device.
UserIdentities::Own(_) => {
own_identity.is_device_signed(self).map_or(false, |_| true)
}
// If it's a device from someone else, first check pub(crate) fn is_cross_signing_trusted(
// that our user has signed the other user and then &self,
// check if the other user has signed this device. own_identity: &Option<ReadOnlyOwnUserIdentity>,
UserIdentities::Other(device_identity) => { device_owner: &Option<ReadOnlyUserIdentities>,
own_identity ) -> bool {
.is_identity_signed(device_identity) own_identity.as_ref().map_or(false, |own_identity| {
.map_or(false, |_| true) // Our own identity needs to be marked as verified.
&& device_identity own_identity.is_verified()
.is_device_signed(self) && device_owner
.map_or(false, |_| true) .as_ref()
} .map(|device_identity| match device_identity {
}) // If it's one of our own devices, just check that
.unwrap_or(false) // we signed the device.
}) ReadOnlyUserIdentities::Own(_) => {
} own_identity.is_device_signed(self).map_or(false, |_| true)
}
// If it's a device from someone else, first check
// that our user has signed the other user and then
// check if the other user has signed this device.
ReadOnlyUserIdentities::Other(device_identity) => {
own_identity.is_identity_signed(device_identity).map_or(false, |_| true)
&& device_identity.is_device_signed(self).map_or(false, |_| true)
}
})
.unwrap_or(false)
})
} }
pub(crate) async fn encrypt( pub(crate) async fn encrypt(
&self, &self,
store: &dyn CryptoStore, store: &dyn CryptoStore,
event_type: EventType, content: AnyToDeviceEventContent,
content: Value,
) -> OlmResult<(Session, EncryptedEventContent)> { ) -> OlmResult<(Session, EncryptedEventContent)> {
let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) { let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) {
k k
@ -455,7 +496,7 @@ impl ReadOnlyDevice {
return Err(OlmError::MissingSession); return Err(OlmError::MissingSession);
}; };
let message = session.encrypt(self, event_type, content).await?; let message = session.encrypt(self, content).await?;
Ok((session, message)) Ok((session, message))
} }

View File

@ -21,17 +21,16 @@ use std::{
use futures::future::join_all; use futures::future::join_all;
use matrix_sdk_common::executor::spawn; use matrix_sdk_common::executor::spawn;
use ruma::{ use ruma::{
api::client::r0::keys::get_keys::Response as KeysQueryResponse, api::client::r0::keys::get_keys::Response as KeysQueryResponse, encryption::DeviceKeys,
encryption::DeviceKeys, DeviceId, DeviceIdBox, UserId,
identifiers::{DeviceId, DeviceIdBox, UserId},
}; };
use tracing::{trace, warn}; use tracing::{trace, warn};
use crate::{ use crate::{
error::OlmResult, error::OlmResult,
identities::{ identities::{
MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserIdentities, MasterPubkey, ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities,
UserIdentity, UserSigningPubkey, ReadOnlyUserIdentity, SelfSigningPubkey, UserSigningPubkey,
}, },
requests::KeysQueryRequest, requests::KeysQueryRequest,
store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store}, store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
@ -247,7 +246,7 @@ impl IdentityManager {
let result = if let Some(mut i) = self.store.get_user_identity(user_id).await? { let result = if let Some(mut i) = self.store.get_user_identity(user_id).await? {
match &mut i { match &mut i {
UserIdentities::Own(ref mut identity) => { ReadOnlyUserIdentities::Own(ref mut identity) => {
let user_signing = if let Some(s) = response.user_signing_keys.get(user_id) let user_signing = if let Some(s) = response.user_signing_keys.get(user_id)
{ {
UserSigningPubkey::from(s) UserSigningPubkey::from(s)
@ -262,7 +261,7 @@ impl IdentityManager {
identity.update(master_key, self_signing, user_signing).map(|_| (i, false)) identity.update(master_key, self_signing, user_signing).map(|_| (i, false))
} }
UserIdentities::Other(ref mut identity) => { ReadOnlyUserIdentities::Other(ref mut identity) => {
identity.update(master_key, self_signing).map(|_| (i, false)) identity.update(master_key, self_signing).map(|_| (i, false))
} }
} }
@ -281,8 +280,8 @@ impl IdentityManager {
continue; continue;
} }
OwnUserIdentity::new(master_key, self_signing, user_signing) ReadOnlyOwnUserIdentity::new(master_key, self_signing, user_signing)
.map(|i| (UserIdentities::Own(i), true)) .map(|i| (ReadOnlyUserIdentities::Own(i), true))
} else { } else {
warn!( warn!(
"User identity for our own user {} didn't contain a \ "User identity for our own user {} didn't contain a \
@ -295,8 +294,8 @@ impl IdentityManager {
warn!("User id mismatch in one of the cross signing keys for user {}", user_id); warn!("User id mismatch in one of the cross signing keys for user {}", user_id);
continue; continue;
} else { } else {
UserIdentity::new(master_key, self_signing) ReadOnlyUserIdentity::new(master_key, self_signing)
.map(|i| (UserIdentities::Other(i), true)) .map(|i| (ReadOnlyUserIdentities::Other(i), true))
}; };
match result { match result {

View File

@ -53,8 +53,8 @@ pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
pub(crate) use manager::IdentityManager; pub(crate) use manager::IdentityManager;
use serde::{Deserialize, Deserializer, Serializer}; use serde::{Deserialize, Deserializer, Serializer};
pub use user::{ pub use user::{
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity, MasterPubkey, OwnUserIdentity, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities,
UserSigningPubkey, ReadOnlyUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity, UserSigningPubkey,
}; };
// These methods are only here because Serialize and Deserialize don't seem to // These methods are only here because Serialize and Deserialize don't seem to

View File

@ -15,6 +15,7 @@
use std::{ use std::{
collections::{btree_map::Iter, BTreeMap}, collections::{btree_map::Iter, BTreeMap},
convert::TryFrom, convert::TryFrom,
ops::Deref,
sync::{ sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
@ -22,8 +23,11 @@ use std::{
}; };
use ruma::{ use ruma::{
api::client::r0::keys::{CrossSigningKey, KeyUsage}, encryption::{CrossSigningKey, KeyUsage},
DeviceKeyId, UserId, events::{
key::verification::VerificationMethod, room::message::KeyVerificationRequestEventContent,
},
DeviceIdBox, DeviceKeyId, EventId, RoomId, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::to_value; use serde_json::to_value;
@ -31,7 +35,173 @@ use serde_json::to_value;
use super::{atomic_bool_deserializer, atomic_bool_serializer}; use super::{atomic_bool_deserializer, atomic_bool_serializer};
#[cfg(test)] #[cfg(test)]
use crate::olm::PrivateCrossSigningIdentity; use crate::olm::PrivateCrossSigningIdentity;
use crate::{error::SignatureError, olm::Utility, ReadOnlyDevice}; use crate::{
error::SignatureError, olm::Utility, verification::VerificationMachine, CryptoStoreError,
OutgoingVerificationRequest, ReadOnlyDevice, VerificationRequest,
};
/// Enum over the different user identity types we can have.
#[derive(Debug, Clone)]
pub enum UserIdentities {
/// Our own user identity.
Own(OwnUserIdentity),
/// An identity belonging to another user.
Other(UserIdentity),
}
impl UserIdentities {
/// Destructure the enum into an `OwnUserIdentity` if it's of the correct
/// type.
pub fn own(self) -> Option<OwnUserIdentity> {
match self {
Self::Own(i) => Some(i),
_ => None,
}
}
/// Destructure the enum into an `UserIdentity` if it's of the correct
/// type.
pub fn other(self) -> Option<UserIdentity> {
match self {
Self::Other(i) => Some(i),
_ => None,
}
}
}
impl From<OwnUserIdentity> for UserIdentities {
fn from(i: OwnUserIdentity) -> Self {
Self::Own(i)
}
}
impl From<UserIdentity> for UserIdentities {
fn from(i: UserIdentity) -> Self {
Self::Other(i)
}
}
/// Struct representing a cross signing identity of a user.
///
/// This is the user identity of a user that isn't our own. Other users will
/// only contain a master key and a self signing key, meaning that only device
/// signatures can be checked with this identity.
///
/// This struct wraps a read-only version of the struct and allows verifications
/// to be requested to verify our own device with the user identity.
#[derive(Debug, Clone)]
pub struct OwnUserIdentity {
pub(crate) inner: ReadOnlyOwnUserIdentity,
pub(crate) verification_machine: VerificationMachine,
}
impl Deref for OwnUserIdentity {
type Target = ReadOnlyOwnUserIdentity;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl OwnUserIdentity {
/// Send a verification request to our other devices.
pub async fn request_verification(
&self,
) -> Result<(VerificationRequest, OutgoingVerificationRequest), CryptoStoreError> {
self.request_verification_helper(None).await
}
/// Send a verification request to our other devices while specifying our
/// supported methods.
///
/// # Arguments
///
/// * `methods` - The verification methods that we're supporting.
pub async fn request_verification_with_methods(
&self,
methods: Vec<VerificationMethod>,
) -> Result<(VerificationRequest, OutgoingVerificationRequest), CryptoStoreError> {
self.request_verification_helper(Some(methods)).await
}
async fn request_verification_helper(
&self,
methods: Option<Vec<VerificationMethod>>,
) -> Result<(VerificationRequest, OutgoingVerificationRequest), CryptoStoreError> {
let devices: Vec<DeviceIdBox> = self
.verification_machine
.store
.get_user_devices(self.user_id())
.await?
.into_iter()
.map(|(d, _)| d)
.filter(|d| &**d != self.verification_machine.own_device_id())
.collect();
Ok(self
.verification_machine
.request_to_device_verification(self.user_id(), devices, methods)
.await)
}
}
/// Struct representing a cross signing identity of a user.
///
/// This is the user identity of a user that isn't our own. Other users will
/// only contain a master key and a self signing key, meaning that only device
/// signatures can be checked with this identity.
///
/// This struct wraps a read-only version of the struct and allows verifications
/// to be requested to verify our own device with the user identity.
#[derive(Debug, Clone)]
pub struct UserIdentity {
pub(crate) inner: ReadOnlyUserIdentity,
pub(crate) verification_machine: VerificationMachine,
}
impl Deref for UserIdentity {
type Target = ReadOnlyUserIdentity;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl UserIdentity {
/// Create a `VerificationRequest` object after the verification request
/// content has been sent out.
pub async fn request_verification(
&self,
room_id: &RoomId,
request_event_id: &EventId,
methods: Option<Vec<VerificationMethod>>,
) -> VerificationRequest {
self.verification_machine
.request_verification(&self.inner, room_id, request_event_id, methods)
.await
}
/// Send a verification request to the given user.
///
/// The returned content needs to be sent out into a DM room with the given
/// user.
///
/// After the content has been sent out a `VerificationRequest` can be
/// started with the [`request_verification()`] method.
///
/// [`request_verification()`]: #method.request_verification
pub async fn verification_request_content(
&self,
methods: Option<Vec<VerificationMethod>>,
) -> KeyVerificationRequestEventContent {
VerificationRequest::request(
self.verification_machine.own_user_id(),
self.verification_machine.own_device_id(),
self.user_id(),
methods,
)
}
}
/// Wrapper for a cross signing key marking it as the master key. /// Wrapper for a cross signing key marking it as the master key.
/// ///
@ -213,6 +383,14 @@ impl MasterPubkey {
self.0.keys.get(key_id.as_str()).map(|k| k.as_str()) self.0.keys.get(key_id.as_str()).map(|k| k.as_str())
} }
/// Get the first available master key.
///
/// There's usually only a single master key so this will usually fetch the
/// only key.
pub fn get_first_key(&self) -> Option<&str> {
self.0.keys.values().map(|k| k.as_str()).next()
}
/// Check if the given cross signing sub-key is signed by the master key. /// Check if the given cross signing sub-key is signed by the master key.
/// ///
/// # Arguments /// # Arguments
@ -350,47 +528,47 @@ impl<'a> IntoIterator for &'a SelfSigningPubkey {
/// Enum over the different user identity types we can have. /// Enum over the different user identity types we can have.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UserIdentities { pub enum ReadOnlyUserIdentities {
/// Our own user identity. /// Our own user identity.
Own(OwnUserIdentity), Own(ReadOnlyOwnUserIdentity),
/// Identities of other users. /// Identities of other users.
Other(UserIdentity), Other(ReadOnlyUserIdentity),
} }
impl From<OwnUserIdentity> for UserIdentities { impl From<ReadOnlyOwnUserIdentity> for ReadOnlyUserIdentities {
fn from(identity: OwnUserIdentity) -> Self { fn from(identity: ReadOnlyOwnUserIdentity) -> Self {
UserIdentities::Own(identity) ReadOnlyUserIdentities::Own(identity)
} }
} }
impl From<UserIdentity> for UserIdentities { impl From<ReadOnlyUserIdentity> for ReadOnlyUserIdentities {
fn from(identity: UserIdentity) -> Self { fn from(identity: ReadOnlyUserIdentity) -> Self {
UserIdentities::Other(identity) ReadOnlyUserIdentities::Other(identity)
} }
} }
impl UserIdentities { impl ReadOnlyUserIdentities {
/// The unique user id of this identity. /// The unique user id of this identity.
pub fn user_id(&self) -> &UserId { pub fn user_id(&self) -> &UserId {
match self { match self {
UserIdentities::Own(i) => i.user_id(), ReadOnlyUserIdentities::Own(i) => i.user_id(),
UserIdentities::Other(i) => i.user_id(), ReadOnlyUserIdentities::Other(i) => i.user_id(),
} }
} }
/// Get the master key of the identity. /// Get the master key of the identity.
pub fn master_key(&self) -> &MasterPubkey { pub fn master_key(&self) -> &MasterPubkey {
match self { match self {
UserIdentities::Own(i) => i.master_key(), ReadOnlyUserIdentities::Own(i) => i.master_key(),
UserIdentities::Other(i) => i.master_key(), ReadOnlyUserIdentities::Other(i) => i.master_key(),
} }
} }
/// Get the self-signing key of the identity. /// Get the self-signing key of the identity.
pub fn self_signing_key(&self) -> &SelfSigningPubkey { pub fn self_signing_key(&self) -> &SelfSigningPubkey {
match self { match self {
UserIdentities::Own(i) => &i.self_signing_key, ReadOnlyUserIdentities::Own(i) => &i.self_signing_key,
UserIdentities::Other(i) => &i.self_signing_key, ReadOnlyUserIdentities::Other(i) => &i.self_signing_key,
} }
} }
@ -398,32 +576,32 @@ impl UserIdentities {
/// own user identity.. /// own user identity..
pub fn user_signing_key(&self) -> Option<&UserSigningPubkey> { pub fn user_signing_key(&self) -> Option<&UserSigningPubkey> {
match self { match self {
UserIdentities::Own(i) => Some(&i.user_signing_key), ReadOnlyUserIdentities::Own(i) => Some(&i.user_signing_key),
UserIdentities::Other(_) => None, ReadOnlyUserIdentities::Other(_) => None,
} }
} }
/// Destructure the enum into an `OwnUserIdentity` if it's of the correct /// Destructure the enum into an `ReadOnlyOwnUserIdentity` if it's of the
/// type. /// correct type.
pub fn own(&self) -> Option<&OwnUserIdentity> { pub fn own(&self) -> Option<&ReadOnlyOwnUserIdentity> {
match self { match self {
UserIdentities::Own(i) => Some(i), ReadOnlyUserIdentities::Own(i) => Some(i),
_ => None, _ => None,
} }
} }
/// Destructure the enum into an `UserIdentity` if it's of the correct /// Destructure the enum into an `UserIdentity` if it's of the correct
/// type. /// type.
pub fn other(&self) -> Option<&UserIdentity> { pub fn other(&self) -> Option<&ReadOnlyUserIdentity> {
match self { match self {
UserIdentities::Other(i) => Some(i), ReadOnlyUserIdentities::Other(i) => Some(i),
_ => None, _ => None,
} }
} }
} }
impl PartialEq for UserIdentities { impl PartialEq for ReadOnlyUserIdentities {
fn eq(&self, other: &UserIdentities) -> bool { fn eq(&self, other: &ReadOnlyUserIdentities) -> bool {
self.user_id() == other.user_id() self.user_id() == other.user_id()
} }
} }
@ -434,13 +612,13 @@ impl PartialEq for UserIdentities {
/// only contain a master key and a self signing key, meaning that only device /// only contain a master key and a self signing key, meaning that only device
/// signatures can be checked with this identity. /// signatures can be checked with this identity.
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserIdentity { pub struct ReadOnlyUserIdentity {
user_id: Arc<UserId>, user_id: Arc<UserId>,
pub(crate) master_key: MasterPubkey, pub(crate) master_key: MasterPubkey,
self_signing_key: SelfSigningPubkey, self_signing_key: SelfSigningPubkey,
} }
impl UserIdentity { impl ReadOnlyUserIdentity {
/// Create a new user identity with the given master and self signing key. /// Create a new user identity with the given master and self signing key.
/// ///
/// # Arguments /// # Arguments
@ -535,7 +713,7 @@ impl UserIdentity {
/// This identity can verify other identities as well as devices belonging to /// This identity can verify other identities as well as devices belonging to
/// the identity. /// the identity.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OwnUserIdentity { pub struct ReadOnlyOwnUserIdentity {
user_id: Arc<UserId>, user_id: Arc<UserId>,
master_key: MasterPubkey, master_key: MasterPubkey,
self_signing_key: SelfSigningPubkey, self_signing_key: SelfSigningPubkey,
@ -547,7 +725,7 @@ pub struct OwnUserIdentity {
verified: Arc<AtomicBool>, verified: Arc<AtomicBool>,
} }
impl OwnUserIdentity { impl ReadOnlyOwnUserIdentity {
/// Create a new own user identity with the given master, self signing, and /// Create a new own user identity with the given master, self signing, and
/// user signing key. /// user signing key.
/// ///
@ -607,7 +785,10 @@ impl OwnUserIdentity {
/// ///
/// Returns an empty result if the signature check succeeded, otherwise a /// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed. /// SignatureError indicating why the check failed.
pub fn is_identity_signed(&self, identity: &UserIdentity) -> Result<(), SignatureError> { pub fn is_identity_signed(
&self,
identity: &ReadOnlyUserIdentity,
) -> Result<(), SignatureError> {
self.user_signing_key.verify_master_key(&identity.master_key) self.user_signing_key.verify_master_key(&identity.master_key)
} }
@ -684,7 +865,7 @@ pub(crate) mod test {
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use ruma::{api::client::r0::keys::get_keys::Response as KeyQueryResponse, user_id}; use ruma::{api::client::r0::keys::get_keys::Response as KeyQueryResponse, user_id};
use super::{OwnUserIdentity, UserIdentities, UserIdentity}; use super::{ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, ReadOnlyUserIdentity};
use crate::{ use crate::{
identities::{ identities::{
manager::test::{other_key_query, own_key_query}, manager::test::{other_key_query, own_key_query},
@ -702,28 +883,29 @@ pub(crate) mod test {
(first, second) (first, second)
} }
fn own_identity(response: &KeyQueryResponse) -> OwnUserIdentity { fn own_identity(response: &KeyQueryResponse) -> ReadOnlyOwnUserIdentity {
let user_id = user_id!("@example:localhost"); let user_id = user_id!("@example:localhost");
let master_key = response.master_keys.get(&user_id).unwrap(); let master_key = response.master_keys.get(&user_id).unwrap();
let user_signing = response.user_signing_keys.get(&user_id).unwrap(); let user_signing = response.user_signing_keys.get(&user_id).unwrap();
let self_signing = response.self_signing_keys.get(&user_id).unwrap(); let self_signing = response.self_signing_keys.get(&user_id).unwrap();
OwnUserIdentity::new(master_key.into(), self_signing.into(), user_signing.into()).unwrap() ReadOnlyOwnUserIdentity::new(master_key.into(), self_signing.into(), user_signing.into())
.unwrap()
} }
pub(crate) fn get_own_identity() -> OwnUserIdentity { pub(crate) fn get_own_identity() -> ReadOnlyOwnUserIdentity {
own_identity(&own_key_query()) own_identity(&own_key_query())
} }
pub(crate) fn get_other_identity() -> UserIdentity { pub(crate) fn get_other_identity() -> ReadOnlyUserIdentity {
let user_id = user_id!("@example2:localhost"); let user_id = user_id!("@example2:localhost");
let response = other_key_query(); let response = other_key_query();
let master_key = response.master_keys.get(&user_id).unwrap(); let master_key = response.master_keys.get(&user_id).unwrap();
let self_signing = response.self_signing_keys.get(&user_id).unwrap(); let self_signing = response.self_signing_keys.get(&user_id).unwrap();
UserIdentity::new(master_key.into(), self_signing.into()).unwrap() ReadOnlyUserIdentity::new(master_key.into(), self_signing.into()).unwrap()
} }
#[test] #[test]
@ -735,7 +917,8 @@ pub(crate) mod test {
let user_signing = response.user_signing_keys.get(&user_id).unwrap(); let user_signing = response.user_signing_keys.get(&user_id).unwrap();
let self_signing = response.self_signing_keys.get(&user_id).unwrap(); let self_signing = response.self_signing_keys.get(&user_id).unwrap();
OwnUserIdentity::new(master_key.into(), self_signing.into(), user_signing.into()).unwrap(); ReadOnlyOwnUserIdentity::new(master_key.into(), self_signing.into(), user_signing.into())
.unwrap();
} }
#[test] #[test]
@ -765,7 +948,7 @@ pub(crate) mod test {
verification_machine: verification_machine.clone(), verification_machine: verification_machine.clone(),
private_identity: private_identity.clone(), private_identity: private_identity.clone(),
own_identity: Some(identity.clone()), own_identity: Some(identity.clone()),
device_owner_identity: Some(UserIdentities::Own(identity.clone())), device_owner_identity: Some(ReadOnlyUserIdentities::Own(identity.clone())),
}; };
let second = Device { let second = Device {
@ -773,18 +956,18 @@ pub(crate) mod test {
verification_machine, verification_machine,
private_identity, private_identity,
own_identity: Some(identity.clone()), own_identity: Some(identity.clone()),
device_owner_identity: Some(UserIdentities::Own(identity.clone())), device_owner_identity: Some(ReadOnlyUserIdentities::Own(identity.clone())),
}; };
assert!(!second.trust_state()); assert!(!second.is_locally_trusted());
assert!(!second.is_trusted()); assert!(!second.is_cross_signing_trusted());
assert!(!first.trust_state()); assert!(!first.is_locally_trusted());
assert!(!first.is_trusted()); assert!(!first.is_cross_signing_trusted());
identity.mark_as_verified(); identity.mark_as_verified();
assert!(second.trust_state()); assert!(second.verified());
assert!(!first.trust_state()); assert!(!first.verified());
} }
#[async_test] #[async_test]
@ -803,7 +986,7 @@ pub(crate) mod test {
Arc::new(MemoryStore::new()), Arc::new(MemoryStore::new()),
); );
let public_identity = identity.as_public_identity().await.unwrap(); let public_identity = identity.to_public_identity().await.unwrap();
let mut device = Device { let mut device = Device {
inner: device, inner: device,
@ -813,12 +996,12 @@ pub(crate) mod test {
device_owner_identity: Some(public_identity.clone().into()), device_owner_identity: Some(public_identity.clone().into()),
}; };
assert!(!device.trust_state()); assert!(!device.verified());
let mut device_keys = device.as_device_keys(); let mut device_keys = device.as_device_keys();
identity.sign_device_keys(&mut device_keys).await.unwrap(); identity.sign_device_keys(&mut device_keys).await.unwrap();
device.inner.signatures = Arc::new(device_keys.signatures); device.inner.signatures = Arc::new(device_keys.signatures);
assert!(device.trust_state()); assert!(device.verified());
} }
} }

View File

@ -20,21 +20,20 @@
// If we don't trust the device store an object that remembers the request and // If we don't trust the device store an object that remembers the request and
// let the users introspect that object. // let the users introspect that object.
use std::{collections::BTreeMap, sync::Arc}; use std::sync::Arc;
use dashmap::{mapref::entry::Entry, DashMap, DashSet}; use dashmap::{mapref::entry::Entry, DashMap, DashSet};
use matrix_sdk_common::uuid::Uuid; use matrix_sdk_common::uuid::Uuid;
use ruma::{ use ruma::{
api::client::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent, forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
room_key_request::{Action, RequestedKeyInfo, RoomKeyRequestToDeviceEventContent}, room_key_request::{Action, RequestedKeyInfo, RoomKeyRequestToDeviceEventContent},
AnyToDeviceEvent, EventType, ToDeviceEvent, AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent,
}, },
identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId}, to_device::DeviceIdOrAllDevices,
DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::value::to_raw_value;
use thiserror::Error; use thiserror::Error;
use tracing::{error, info, trace, warn}; use tracing::{error, info, trace, warn};
@ -150,7 +149,7 @@ pub struct OutgoingKeyRequest {
} }
impl OutgoingKeyRequest { impl OutgoingKeyRequest {
fn to_request(&self, own_device_id: &DeviceId) -> Result<OutgoingRequest, serde_json::Error> { fn to_request(&self, own_device_id: &DeviceId) -> OutgoingRequest {
let content = RoomKeyRequestToDeviceEventContent::new( let content = RoomKeyRequestToDeviceEventContent::new(
Action::Request, Action::Request,
Some(self.info.clone()), Some(self.info.clone()),
@ -158,13 +157,17 @@ impl OutgoingKeyRequest {
self.request_id.to_string(), self.request_id.to_string(),
); );
wrap_key_request_content(self.request_recipient.clone(), self.request_id, &content) let request = ToDeviceRequest::new_with_id(
&self.request_recipient,
DeviceIdOrAllDevices::AllDevices,
AnyToDeviceEventContent::RoomKeyRequest(content),
self.request_id,
);
OutgoingRequest { request_id: request.txn_id, request: Arc::new(request.into()) }
} }
fn to_cancellation( fn to_cancellation(&self, own_device_id: &DeviceId) -> OutgoingRequest {
&self,
own_device_id: &DeviceId,
) -> Result<OutgoingRequest, serde_json::Error> {
let content = RoomKeyRequestToDeviceEventContent::new( let content = RoomKeyRequestToDeviceEventContent::new(
Action::CancelRequest, Action::CancelRequest,
None, None,
@ -172,8 +175,13 @@ impl OutgoingKeyRequest {
self.request_id.to_string(), self.request_id.to_string(),
); );
let id = Uuid::new_v4(); let request = ToDeviceRequest::new(
wrap_key_request_content(self.request_recipient.clone(), id, &content) &self.request_recipient,
DeviceIdOrAllDevices::AllDevices,
AnyToDeviceEventContent::RoomKeyRequest(content),
);
OutgoingRequest { request_id: request.txn_id, request: Arc::new(request.into()) }
} }
} }
@ -187,26 +195,6 @@ impl PartialEq for OutgoingKeyRequest {
} }
} }
fn wrap_key_request_content(
recipient: UserId,
id: Uuid,
content: &RoomKeyRequestToDeviceEventContent,
) -> Result<OutgoingRequest, serde_json::Error> {
let mut messages = BTreeMap::new();
messages
.entry(recipient)
.or_insert_with(BTreeMap::new)
.insert(DeviceIdOrAllDevices::AllDevices, to_raw_value(content)?);
Ok(OutgoingRequest {
request_id: id,
request: Arc::new(
ToDeviceRequest { event_type: EventType::RoomKeyRequest, txn_id: id, messages }.into(),
),
})
}
impl KeyRequestMachine { impl KeyRequestMachine {
pub fn new( pub fn new(
user_id: Arc<UserId>, user_id: Arc<UserId>,
@ -229,13 +217,14 @@ impl KeyRequestMachine {
/// Load stored outgoing requests that were not yet sent out. /// Load stored outgoing requests that were not yet sent out.
async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> { async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
self.store Ok(self
.store
.get_unsent_key_requests() .get_unsent_key_requests()
.await? .await?
.into_iter() .into_iter()
.filter(|i| !i.sent_out) .filter(|i| !i.sent_out)
.map(|info| info.to_request(self.device_id()).map_err(CryptoStoreError::from)) .map(|info| info.to_request(self.device_id()))
.collect() .collect())
} }
/// Our own user id. /// Our own user id.
@ -448,23 +437,15 @@ impl KeyRequestMachine {
let (used_session, content) = let (used_session, content) =
device.encrypt_session(session.clone(), message_index).await?; device.encrypt_session(session.clone(), message_index).await?;
let id = Uuid::new_v4(); let request = ToDeviceRequest::new(
let mut messages = BTreeMap::new(); device.user_id(),
device.device_id().to_owned(),
messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert( AnyToDeviceEventContent::RoomEncrypted(content),
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
); );
let request = OutgoingRequest { let request =
request_id: id, OutgoingRequest { request_id: request.txn_id, request: Arc::new(request.into()) };
request: Arc::new( self.outgoing_to_device_requests.insert(request.request_id, request);
ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages }
.into(),
),
};
self.outgoing_to_device_requests.insert(id, request);
Ok(used_session) Ok(used_session)
} }
@ -498,7 +479,7 @@ impl KeyRequestMachine {
.flatten(); .flatten();
let own_device_check = || { let own_device_check = || {
if device.trust_state() { if device.verified() {
Ok(None) Ok(None)
} else { } else {
Err(KeyshareDecision::UntrustedDevice) Err(KeyshareDecision::UntrustedDevice)
@ -594,8 +575,8 @@ impl KeyRequestMachine {
let request = self.store.get_key_request_by_info(&key_info).await?; let request = self.store.get_key_request_by_info(&key_info).await?;
if let Some(request) = request { if let Some(request) = request {
let cancel = request.to_cancellation(self.device_id())?; let cancel = request.to_cancellation(self.device_id());
let request = request.to_request(self.device_id())?; let request = request.to_request(self.device_id());
Ok((Some(cancel), request)) Ok((Some(cancel), request))
} else { } else {
@ -618,7 +599,7 @@ impl KeyRequestMachine {
sent_out: false, sent_out: false,
}; };
let outgoing_request = request.to_request(self.device_id())?; let outgoing_request = request.to_request(self.device_id());
self.save_outgoing_key_info(request).await?; self.save_outgoing_key_info(request).await?;
Ok(outgoing_request) Ok(outgoing_request)
@ -717,7 +698,7 @@ impl KeyRequestMachine {
// can delete it in one transaction. // can delete it in one transaction.
self.delete_key_info(&key_info).await?; self.delete_key_info(&key_info).await?;
let request = key_info.to_cancellation(self.device_id())?; let request = key_info.to_cancellation(self.device_id());
self.outgoing_to_device_requests.insert(request.request_id, request); self.outgoing_to_device_requests.insert(request.request_id, request);
Ok(()) Ok(())
@ -789,13 +770,14 @@ mod test {
use matrix_sdk_common::locks::Mutex; use matrix_sdk_common::locks::Mutex;
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use ruma::{ use ruma::{
api::client::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent, forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
room::encrypted::EncryptedEventContent, room::encrypted::EncryptedEventContent,
room_key_request::RoomKeyRequestToDeviceEventContent, AnyToDeviceEvent, ToDeviceEvent, room_key_request::RoomKeyRequestToDeviceEventContent, AnyToDeviceEvent, ToDeviceEvent,
}, },
room_id, user_id, DeviceIdBox, RoomId, UserId, room_id,
to_device::DeviceIdOrAllDevices,
user_id, DeviceIdBox, RoomId, UserId,
}; };
use super::{KeyRequestMachine, KeyshareDecision}; use super::{KeyRequestMachine, KeyshareDecision};
@ -1203,8 +1185,7 @@ mod test {
.unwrap() .unwrap()
.get(&DeviceIdOrAllDevices::AllDevices) .get(&DeviceIdOrAllDevices::AllDevices)
.unwrap(); .unwrap();
let content: RoomKeyRequestToDeviceEventContent = let content: RoomKeyRequestToDeviceEventContent = content.deserialize_as().unwrap();
serde_json::from_str(content.get()).unwrap();
alice_machine.mark_outgoing_request_as_sent(id).await.unwrap(); alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
@ -1233,7 +1214,7 @@ mod test {
.unwrap() .unwrap()
.get(&DeviceIdOrAllDevices::DeviceId(alice_device_id())) .get(&DeviceIdOrAllDevices::DeviceId(alice_device_id()))
.unwrap(); .unwrap();
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); let content: EncryptedEventContent = content.deserialize_as().unwrap();
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
@ -1336,8 +1317,7 @@ mod test {
.unwrap() .unwrap()
.get(&DeviceIdOrAllDevices::AllDevices) .get(&DeviceIdOrAllDevices::AllDevices)
.unwrap(); .unwrap();
let content: RoomKeyRequestToDeviceEventContent = let content: RoomKeyRequestToDeviceEventContent = content.deserialize_as().unwrap();
serde_json::from_str(content.get()).unwrap();
alice_machine.mark_outgoing_request_as_sent(id).await.unwrap(); alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
@ -1382,7 +1362,7 @@ mod test {
.unwrap() .unwrap()
.get(&DeviceIdOrAllDevices::DeviceId(alice_device_id())) .get(&DeviceIdOrAllDevices::DeviceId(alice_device_id()))
.unwrap(); .unwrap();
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); let content: EncryptedEventContent = content.deserialize_as().unwrap();
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();

View File

@ -45,9 +45,11 @@ pub use file_encryption::{
DecryptorError, EncryptionInfo, KeyExportError, DecryptorError, EncryptionInfo, KeyExportError,
}; };
pub use identities::{ pub use identities::{
Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserDevices, UserIdentities, UserIdentity, Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, ReadOnlyOwnUserIdentity,
ReadOnlyUserIdentities, ReadOnlyUserIdentity, UserDevices, UserIdentities, UserIdentity,
}; };
pub use machine::OlmMachine; pub use machine::OlmMachine;
pub use matrix_qrcode;
pub use olm::EncryptionSettings; pub use olm::EncryptionSettings;
pub(crate) use olm::ReadOnlyAccount; pub(crate) use olm::ReadOnlyAccount;
pub use requests::{ pub use requests::{
@ -55,4 +57,6 @@ pub use requests::{
OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest, OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest,
}; };
pub use store::CryptoStoreError; pub use store::CryptoStoreError;
pub use verification::{AcceptSettings, Sas, VerificationRequest}; pub use verification::{
AcceptSettings, CancelInfo, QrVerification, Sas, Verification, VerificationRequest,
};

View File

@ -46,7 +46,7 @@ use tracing::{debug, error, info, trace, warn};
use crate::store::sled::SledStore; use crate::store::sled::SledStore;
use crate::{ use crate::{
error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}, error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult},
identities::{Device, IdentityManager, UserDevices}, identities::{user::UserIdentities, Device, IdentityManager, UserDevices},
key_request::KeyRequestMachine, key_request::KeyRequestMachine,
olm::{ olm::{
Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys,
@ -59,7 +59,7 @@ use crate::{
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult, Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
Store, Store,
}, },
verification::{Sas, VerificationMachine, VerificationRequest}, verification::{Verification, VerificationMachine, VerificationRequest},
ToDeviceRequest, ToDeviceRequest,
}; };
@ -379,7 +379,7 @@ impl OlmMachine {
*identity = id; *identity = id;
let public = identity.as_public_identity().await.expect( let public = identity.to_public_identity().await.expect(
"Couldn't create a public version of the identity from a new private identity", "Couldn't create a public version of the identity from a new private identity",
); );
@ -717,21 +717,27 @@ impl OlmMachine {
Ok(()) Ok(())
} }
/// Get a `Sas` verification object with the given flow id. /// Get a verification object for the given user id with the given flow id.
pub fn get_verification(&self, flow_id: &str) -> Option<Sas> { pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
self.verification_machine.get_sas(flow_id) self.verification_machine.get_verification(user_id, flow_id)
} }
/// Get a verification request object with the given flow id. /// Get a verification request object with the given flow id.
pub fn get_verification_request( pub fn get_verification_request(
&self, &self,
user_id: &UserId,
flow_id: impl AsRef<str>, flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> { ) -> Option<VerificationRequest> {
self.verification_machine.get_request(flow_id) self.verification_machine.get_request(user_id, flow_id)
} }
async fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) { /// Get all the verification requests of a given user.
self.account.update_uploaded_key_count(key_count).await; pub fn get_verification_requests(&self, user_id: &UserId) -> Vec<VerificationRequest> {
self.verification_machine.get_requests(user_id)
}
fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
self.account.update_uploaded_key_count(key_count);
} }
async fn handle_to_device_event(&self, event: &AnyToDeviceEvent) { async fn handle_to_device_event(&self, event: &AnyToDeviceEvent) {
@ -749,11 +755,11 @@ impl OlmMachine {
| AnyToDeviceEvent::KeyVerificationStart(..) => { | AnyToDeviceEvent::KeyVerificationStart(..) => {
self.handle_verification_event(event).await; self.handle_verification_event(event).await;
} }
AnyToDeviceEvent::Dummy(_) => {} AnyToDeviceEvent::Dummy(_)
AnyToDeviceEvent::RoomKey(_) => {} | AnyToDeviceEvent::RoomKey(_)
AnyToDeviceEvent::ForwardedRoomKey(_) => {} | AnyToDeviceEvent::ForwardedRoomKey(_)
AnyToDeviceEvent::RoomEncrypted(_) => {} | AnyToDeviceEvent::RoomEncrypted(_) => {}
AnyToDeviceEvent::Custom(_) => {} _ => {}
} }
} }
@ -783,14 +789,14 @@ impl OlmMachine {
one_time_keys_counts: &BTreeMap<DeviceKeyAlgorithm, UInt>, one_time_keys_counts: &BTreeMap<DeviceKeyAlgorithm, UInt>,
) -> OlmResult<ToDevice> { ) -> OlmResult<ToDevice> {
// Remove verification objects that have expired or are done. // Remove verification objects that have expired or are done.
self.verification_machine.garbage_collect(); let mut events = self.verification_machine.garbage_collect();
// Always save the account, a new session might get created which also // Always save the account, a new session might get created which also
// touches the account. // touches the account.
let mut changes = let mut changes =
Changes { account: Some(self.account.inner.clone()), ..Default::default() }; Changes { account: Some(self.account.inner.clone()), ..Default::default() };
self.update_one_time_key_count(one_time_keys_counts).await; self.update_one_time_key_count(one_time_keys_counts);
for user_id in &changed_devices.changed { for user_id in &changed_devices.changed {
if let Err(e) = self.identity_manager.mark_user_as_changed(user_id).await { if let Err(e) = self.identity_manager.mark_user_as_changed(user_id).await {
@ -798,8 +804,6 @@ impl OlmMachine {
} }
} }
let mut events = Vec::new();
for mut raw_event in to_device_events.events { for mut raw_event in to_device_events.events {
let event = match raw_event.deserialize() { let event = match raw_event.deserialize() {
Ok(e) => e, Ok(e) => e,
@ -922,7 +926,7 @@ impl OlmMachine {
.unwrap_or(false) .unwrap_or(false)
}) { }) {
if (self.user_id() == device.user_id() && self.device_id() == device.device_id()) if (self.user_id() == device.user_id() && self.device_id() == device.device_id())
|| device.is_trusted() || device.verified()
{ {
VerificationState::Trusted VerificationState::Trusted
} else { } else {
@ -1048,6 +1052,18 @@ impl OlmMachine {
self.store.get_device(user_id, device_id).await self.store.get_device(user_id, device_id).await
} }
/// Get the cross signing user identity of a user.
///
/// # Arguments
///
/// * `user_id` - The unique id of the user that the identity belongs to
///
/// Returns a `UserIdentities` enum if one is found and the crypto store
/// didn't throw an error.
pub async fn get_identity(&self, user_id: &UserId) -> StoreResult<Option<UserIdentities>> {
self.store.get_identity(user_id).await
}
/// Get a map holding all the devices of an user. /// Get a map holding all the devices of an user.
/// ///
/// # Arguments /// # Arguments
@ -1225,22 +1241,22 @@ pub(crate) mod test {
use matrix_sdk_test::test_json; use matrix_sdk_test::test_json;
use ruma::{ use ruma::{
api::{ api::{
client::r0::keys::{claim_keys, get_keys, upload_keys, OneTimeKey}, client::r0::keys::{claim_keys, get_keys, upload_keys},
IncomingResponse, IncomingResponse,
}, },
encryption::OneTimeKey,
event_id,
events::{ events::{
dummy::DummyToDeviceEventContent,
room::{ room::{
encrypted::EncryptedEventContent, encrypted::EncryptedEventContent,
message::{MessageEventContent, MessageType}, message::{MessageEventContent, MessageType},
}, },
AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, AnyToDeviceEvent, AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, AnyToDeviceEvent,
EventType, SyncMessageEvent, ToDeviceEvent, Unsigned, AnyToDeviceEventContent, SyncMessageEvent, ToDeviceEvent, Unsigned,
}, },
identifiers::{ room_id, uint, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId,
event_id, room_id, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, UserId, MilliSecondsSinceUnixEpoch, UserId,
},
serde::Raw,
uint, MilliSecondsSinceUnixEpoch,
}; };
use serde_json::json; use serde_json::json;
@ -1285,12 +1301,16 @@ pub(crate) mod test {
fn to_device_requests_to_content(requests: Vec<Arc<ToDeviceRequest>>) -> EncryptedEventContent { fn to_device_requests_to_content(requests: Vec<Arc<ToDeviceRequest>>) -> EncryptedEventContent {
let to_device_request = &requests[0]; let to_device_request = &requests[0];
let content: Raw<EncryptedEventContent> = serde_json::from_str( to_device_request
to_device_request.messages.values().next().unwrap().values().next().unwrap().get(), .messages
) .values()
.unwrap(); .next()
.unwrap()
content.deserialize().unwrap() .values()
.next()
.unwrap()
.deserialize_as()
.unwrap()
} }
pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
@ -1352,7 +1372,10 @@ pub(crate) mod test {
let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap(); let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
let (session, content) = bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap(); let (session, content) = bob_device
.encrypt(AnyToDeviceEventContent::Dummy(DummyToDeviceEventContent::new()))
.await
.unwrap();
alice.store.save_sessions(&[session]).await.unwrap(); alice.store.save_sessions(&[session]).await.unwrap();
let event = ToDeviceEvent { sender: alice.user_id().clone(), content }; let event = ToDeviceEvent { sender: alice.user_id().clone(), content };
@ -1387,6 +1410,10 @@ pub(crate) mod test {
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50)); response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine.receive_keys_upload_response(&response).await.unwrap(); machine.receive_keys_upload_response(&response).await.unwrap();
assert!(!machine.should_upload_keys().await); assert!(!machine.should_upload_keys().await);
response.one_time_key_counts.remove(&DeviceKeyAlgorithm::SignedCurve25519);
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(!machine.should_upload_keys().await);
} }
#[tokio::test] #[tokio::test]
@ -1587,7 +1614,11 @@ pub(crate) mod test {
let event = ToDeviceEvent { let event = ToDeviceEvent {
sender: alice.user_id().clone(), sender: alice.user_id().clone(),
content: bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap().1, content: bob_device
.encrypt(AnyToDeviceEventContent::Dummy(DummyToDeviceEventContent::new()))
.await
.unwrap()
.1,
}; };
let event = bob.decrypt_to_device_event(&event).await.unwrap().event.deserialize().unwrap(); let event = bob.decrypt_to_device_event(&event).await.unwrap().event.deserialize().unwrap();
@ -1754,14 +1785,18 @@ pub(crate) mod test {
let bob_device = alice.get_device(bob.user_id(), bob.device_id()).await.unwrap().unwrap(); let bob_device = alice.get_device(bob.user_id(), bob.device_id()).await.unwrap().unwrap();
assert!(!bob_device.is_trusted()); assert!(!bob_device.verified());
let (alice_sas, request) = bob_device.start_verification().await.unwrap(); let (alice_sas, request) = bob_device.start_verification().await.unwrap();
let event = request_to_event(alice.user_id(), &request.into()); let event = request_to_event(alice.user_id(), &request.into());
bob.handle_verification_event(&event).await; bob.handle_verification_event(&event).await;
let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap(); let bob_sas = bob
.get_verification(alice.user_id(), alice_sas.flow_id().as_str())
.unwrap()
.sas_v1()
.unwrap();
assert!(alice_sas.emoji().is_none()); assert!(alice_sas.emoji().is_none());
assert!(bob_sas.emoji().is_none()); assert!(bob_sas.emoji().is_none());
@ -1813,14 +1848,14 @@ pub(crate) mod test {
.unwrap(); .unwrap();
assert!(alice_sas.is_done()); assert!(alice_sas.is_done());
assert!(bob_device.is_trusted()); assert!(bob_device.verified());
let alice_device = let alice_device =
bob.get_device(alice.user_id(), alice.device_id()).await.unwrap().unwrap(); bob.get_device(alice.user_id(), alice.device_id()).await.unwrap().unwrap();
assert!(!alice_device.is_trusted()); assert!(!alice_device.verified());
bob.handle_verification_event(&event).await; bob.handle_verification_event(&event).await;
assert!(bob_sas.is_done()); assert!(bob_sas.is_done());
assert!(alice_device.is_trusted()); assert!(alice_device.verified());
} }
} }

View File

@ -14,11 +14,11 @@
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
convert::{TryFrom, TryInto}, convert::TryInto,
fmt, fmt,
ops::Deref, ops::Deref,
sync::{ sync::{
atomic::{AtomicBool, AtomicI64, Ordering}, atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Arc,
}, },
}; };
@ -30,23 +30,16 @@ use olm_rs::{
session::{OlmMessage, PreKeyMessage}, session::{OlmMessage, PreKeyMessage},
PicklingMode, PicklingMode,
}; };
#[cfg(test)]
use ruma::events::EventType;
use ruma::{ use ruma::{
api::client::r0::keys::{ api::client::r0::keys::{upload_keys, upload_signatures::Request as SignatureUploadRequest},
upload_keys, upload_signatures::Request as SignatureUploadRequest, OneTimeKey, SignedKey, encryption::{DeviceKeys, OneTimeKey, SignedKey},
},
encryption::DeviceKeys,
events::{ events::{
room::encrypted::{EncryptedEventContent, EncryptedEventScheme}, room::encrypted::{EncryptedEventContent, EncryptedEventScheme},
AnyToDeviceEvent, ToDeviceEvent, AnyToDeviceEvent, ToDeviceEvent,
}, },
identifiers::{
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId,
UserId,
},
serde::{CanonicalJsonValue, Raw}, serde::{CanonicalJsonValue, Raw},
UInt, DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UInt,
UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
@ -188,9 +181,22 @@ impl Account {
} }
} }
pub async fn update_uploaded_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) { pub fn update_uploaded_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
if let Some(count) = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519) { if let Some(count) = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519) {
let count: u64 = (*count).into(); let count: u64 = (*count).into();
let old_count = self.inner.uploaded_key_count();
// Some servers might always return the key counts in the sync
// response, we don't want to the logs with noop changes if they do
// so.
if count != old_count {
debug!(
"Updated uploaded one-time key count {} -> {}.",
self.inner.uploaded_key_count(),
count
);
}
self.inner.update_uploaded_key_count(count); self.inner.update_uploaded_key_count(count);
} }
} }
@ -204,16 +210,8 @@ impl Account {
} }
self.inner.mark_as_shared(); self.inner.mark_as_shared();
let one_time_key_count = debug!("Marking one-time keys as published");
response.one_time_key_counts.get(&DeviceKeyAlgorithm::SignedCurve25519); self.update_uploaded_key_count(&response.one_time_key_counts);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
debug!(
"Updated uploaded one-time key count {} -> {}, marking keys as published",
self.inner.uploaded_key_count(),
count
);
self.inner.update_uploaded_key_count(count);
self.inner.mark_keys_as_published().await; self.inner.mark_keys_as_published().await;
self.store.save_account(self.inner.clone()).await?; self.store.save_account(self.inner.clone()).await?;
@ -439,7 +437,7 @@ pub struct ReadOnlyAccount {
/// this is None, no action will be taken. After a sync request the client /// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the /// needs to set this for us, depending on the count we will suggest the
/// client to upload new keys. /// client to upload new keys.
uploaded_signed_key_count: Arc<AtomicI64>, uploaded_signed_key_count: Arc<AtomicU64>,
} }
/// A typed representation of a base64 encoded string containing the account /// A typed representation of a base64 encoded string containing the account
@ -475,7 +473,7 @@ pub struct PickledAccount {
/// Was the account shared. /// Was the account shared.
pub shared: bool, pub shared: bool,
/// The number of uploaded one-time keys we have on the server. /// The number of uploaded one-time keys we have on the server.
pub uploaded_signed_key_count: i64, pub uploaded_signed_key_count: u64,
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
@ -506,7 +504,7 @@ impl ReadOnlyAccount {
inner: Arc::new(Mutex::new(account)), inner: Arc::new(Mutex::new(account)),
identity_keys: Arc::new(identity_keys), identity_keys: Arc::new(identity_keys),
shared: Arc::new(AtomicBool::new(false)), shared: Arc::new(AtomicBool::new(false)),
uploaded_signed_key_count: Arc::new(AtomicI64::new(0)), uploaded_signed_key_count: Arc::new(AtomicU64::new(0)),
} }
} }
@ -531,18 +529,17 @@ impl ReadOnlyAccount {
/// ///
/// * `new_count` - The new count that was reported by the server. /// * `new_count` - The new count that was reported by the server.
pub(crate) fn update_uploaded_key_count(&self, new_count: u64) { pub(crate) fn update_uploaded_key_count(&self, new_count: u64) {
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX); self.uploaded_signed_key_count.store(new_count, Ordering::SeqCst);
self.uploaded_signed_key_count.store(key_count, Ordering::Relaxed);
} }
/// Get the currently known uploaded key count. /// Get the currently known uploaded key count.
pub fn uploaded_key_count(&self) -> i64 { pub fn uploaded_key_count(&self) -> u64 {
self.uploaded_signed_key_count.load(Ordering::Relaxed) self.uploaded_signed_key_count.load(Ordering::SeqCst)
} }
/// Has the account been shared with the server. /// Has the account been shared with the server.
pub fn shared(&self) -> bool { pub fn shared(&self) -> bool {
self.shared.load(Ordering::Relaxed) self.shared.load(Ordering::SeqCst)
} }
/// Mark the account as shared. /// Mark the account as shared.
@ -550,7 +547,7 @@ impl ReadOnlyAccount {
/// Messages shouldn't be encrypted with the session before it has been /// Messages shouldn't be encrypted with the session before it has been
/// shared. /// shared.
pub(crate) fn mark_as_shared(&self) { pub(crate) fn mark_as_shared(&self) {
self.shared.store(true, Ordering::Relaxed); self.shared.store(true, Ordering::SeqCst);
} }
/// Get the one-time keys of the account. /// Get the one-time keys of the account.
@ -574,7 +571,7 @@ impl ReadOnlyAccount {
/// ///
/// Returns an empty error if no keys need to be uploaded. /// Returns an empty error if no keys need to be uploaded.
pub(crate) async fn generate_one_time_keys(&self) -> Result<u64, ()> { pub(crate) async fn generate_one_time_keys(&self) -> Result<u64, ()> {
let count = self.uploaded_key_count() as u64; let count = self.uploaded_key_count();
let max_keys = self.max_one_time_keys().await; let max_keys = self.max_one_time_keys().await;
let max_on_server = (max_keys as u64) / 2; let max_on_server = (max_keys as u64) / 2;
@ -595,7 +592,7 @@ impl ReadOnlyAccount {
return true; return true;
} }
let count = self.uploaded_key_count() as u64; let count = self.uploaded_key_count();
// If we have a known key count, check that we have more than // If we have a known key count, check that we have more than
// max_one_time_Keys() / 2, otherwise tell the client to upload more. // max_one_time_Keys() / 2, otherwise tell the client to upload more.
@ -680,7 +677,7 @@ impl ReadOnlyAccount {
inner: Arc::new(Mutex::new(account)), inner: Arc::new(Mutex::new(account)),
identity_keys: Arc::new(identity_keys), identity_keys: Arc::new(identity_keys),
shared: Arc::new(AtomicBool::from(pickle.shared)), shared: Arc::new(AtomicBool::from(pickle.shared)),
uploaded_signed_key_count: Arc::new(AtomicI64::new(pickle.uploaded_signed_key_count)), uploaded_signed_key_count: Arc::new(AtomicU64::new(pickle.uploaded_signed_key_count)),
}) })
} }
@ -993,6 +990,8 @@ impl ReadOnlyAccount {
#[cfg(test)] #[cfg(test)]
pub(crate) async fn create_session_for(&self, other: &ReadOnlyAccount) -> (Session, Session) { pub(crate) async fn create_session_for(&self, other: &ReadOnlyAccount) -> (Session, Session) {
use ruma::events::{dummy::DummyToDeviceEventContent, AnyToDeviceEventContent};
other.generate_one_time_keys_helper(1).await; other.generate_one_time_keys_helper(1).await;
let one_time = other.signed_one_time_keys().await.unwrap(); let one_time = other.signed_one_time_keys().await.unwrap();
@ -1003,7 +1002,10 @@ impl ReadOnlyAccount {
other.mark_keys_as_published().await; other.mark_keys_as_published().await;
let message = our_session.encrypt(&device, EventType::Dummy, json!({})).await.unwrap(); let message = our_session
.encrypt(&device, AnyToDeviceEventContent::Dummy(DummyToDeviceEventContent::new()))
.await
.unwrap();
let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme { let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme {
c c
} else { } else {

View File

@ -32,8 +32,8 @@ use ruma::{
}, },
AnySyncRoomEvent, SyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent,
}, },
identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId},
serde::Raw, serde::Raw,
DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
@ -320,11 +320,9 @@ impl InboundGroupSession {
decrypted_object.get_mut("content").map(|c| c.as_object_mut()).flatten() decrypted_object.get_mut("content").map(|c| c.as_object_mut()).flatten()
{ {
if !decrypted_content.contains_key("m.relates_to") { if !decrypted_content.contains_key("m.relates_to") {
if let Some(relation) = &event.content.relates_to { let content = serde_json::to_value(&event.content)?;
decrypted_content.insert( if let Some(relation) = content.as_object().and_then(|o| o.get("m.relates_to")) {
"m.relates_to".to_owned(), decrypted_content.insert("m.relates_to".to_owned(), relation.to_owned());
serde_json::to_value(relation).unwrap_or_default(),
);
} }
} }
} }

View File

@ -34,20 +34,20 @@ use olm_rs::{
errors::OlmGroupSessionError, outbound_group_session::OlmOutboundGroupSession, PicklingMode, errors::OlmGroupSessionError, outbound_group_session::OlmOutboundGroupSession, PicklingMode,
}; };
use ruma::{ use ruma::{
api::client::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
room::{ room::{
encrypted::{EncryptedEventContent, EncryptedEventScheme, MegolmV1AesSha2ContentInit}, encrypted::{EncryptedEventContent, EncryptedEventScheme, MegolmV1AesSha2ContentInit},
encryption::EncryptionEventContent, encryption::EncryptionEventContent,
history_visibility::HistoryVisibility, history_visibility::HistoryVisibility,
message::Relation,
}, },
AnyMessageEventContent, EventContent, room_key::RoomKeyToDeviceEventContent,
AnyMessageEventContent, AnyToDeviceEventContent, EventContent,
}, },
to_device::DeviceIdOrAllDevices,
DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId, DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::json;
use tracing::{debug, error, trace}; use tracing::{debug, error, trace};
use super::{ use super::{
@ -277,13 +277,8 @@ impl OutboundGroupSession {
"type": content.event_type(), "type": content.event_type(),
}); });
let relates_to: Option<Relation> = json_content
.get("content")
.map(|c| c.get("m.relates_to").cloned().map(|r| serde_json::from_value(r).ok()))
.flatten()
.flatten();
let plaintext = json_content.to_string(); let plaintext = json_content.to_string();
let relation = content.relation();
let ciphertext = self.encrypt_helper(plaintext).await; let ciphertext = self.encrypt_helper(plaintext).await;
@ -297,7 +292,7 @@ impl OutboundGroupSession {
EncryptedEventContent::new( EncryptedEventContent::new(
EncryptedEventScheme::MegolmV1AesSha2(encrypted_content), EncryptedEventScheme::MegolmV1AesSha2(encrypted_content),
relates_to, relation,
) )
} }
@ -361,16 +356,15 @@ impl OutboundGroupSession {
session.session_message_index() session.session_message_index()
} }
/// Get the outbound group session key as a json value that can be sent as a pub(crate) async fn as_content(&self) -> AnyToDeviceEventContent {
/// m.room_key. let session_key = self.session_key().await;
pub async fn as_json(&self) -> Value {
json!({ AnyToDeviceEventContent::RoomKey(RoomKeyToDeviceEventContent::new(
"algorithm": EventEncryptionAlgorithm::MegolmV1AesSha2, EventEncryptionAlgorithm::MegolmV1AesSha2,
"room_id": &*self.room_id, self.room_id().to_owned(),
"session_id": &*self.session_id, self.session_id().to_owned(),
"session_key": self.session_key().await, session_key.0.clone(),
"chain_index": self.message_index().await, ))
})
} }
/// Has or will the session be shared with the given user/device pair. /// Has or will the session be shared with the given user/device pair.

View File

@ -61,12 +61,19 @@ where
pub(crate) mod test { pub(crate) mod test {
use std::{collections::BTreeMap, convert::TryInto}; use std::{collections::BTreeMap, convert::TryInto};
use matches::assert_matches;
use olm_rs::session::OlmMessage; use olm_rs::session::OlmMessage;
use ruma::{ use ruma::{
api::client::r0::keys::SignedKey, encryption::SignedKey,
events::forwarded_room_key::ForwardedRoomKeyToDeviceEventContent, room_id, user_id, event_id,
DeviceId, UserId, events::{
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
room::message::{MessageEventContent, Relation, Replacement},
AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent,
},
room_id, user_id, DeviceId, UserId,
}; };
use serde_json::json;
use crate::olm::{InboundGroupSession, ReadOnlyAccount, Session}; use crate::olm::{InboundGroupSession, ReadOnlyAccount, Session};
@ -215,6 +222,71 @@ pub(crate) mod test {
assert_eq!(plaintext, inbound.decrypt_helper(ciphertext).await.unwrap().0); assert_eq!(plaintext, inbound.decrypt_helper(ciphertext).await.unwrap().0);
} }
#[tokio::test]
async fn edit_decryption() {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost");
let event_id = event_id!("$1234adfad:asdf");
let (outbound, _) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
outbound.mark_as_shared();
assert!(outbound.shared());
let mut content = MessageEventContent::text_plain("Hello");
content.relates_to = Some(Relation::Replacement(Replacement::new(
event_id.clone(),
MessageEventContent::text_plain("Hello edit").into(),
)));
let inbound = InboundGroupSession::new(
"test_key",
"test_key",
&room_id,
outbound.session_key().await,
None,
)
.unwrap();
assert_eq!(0, inbound.first_known_index());
assert_eq!(outbound.session_id(), inbound.session_id());
let encrypted_content =
outbound.encrypt(AnyMessageEventContent::RoomMessage(content)).await;
let event = json!({
"sender": alice.user_id(),
"event_id": event_id,
"origin_server_ts": 0,
"room_id": room_id,
"type": "m.room.encrypted",
"content": encrypted_content,
})
.to_string();
let event: AnySyncRoomEvent = serde_json::from_str(&event).expect("WHAAAT?!?!?");
let event =
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomEncrypted(event)) = event {
event
} else {
panic!("Invalid event type")
};
let decrypted = inbound.decrypt(&event).await.unwrap().0;
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(e)) =
decrypted.deserialize().unwrap()
{
assert_matches!(e.content.relates_to, Some(Relation::Replacement(_)));
} else {
panic!("Invalid event type")
}
}
#[tokio::test] #[tokio::test]
async fn group_session_export() { async fn group_session_export() {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());

View File

@ -26,12 +26,12 @@ use ruma::{
CiphertextInfo, EncryptedEventContent, EncryptedEventScheme, CiphertextInfo, EncryptedEventContent, EncryptedEventScheme,
OlmV1Curve25519AesSha2Content, OlmV1Curve25519AesSha2Content,
}, },
EventType, AnyToDeviceEventContent, EventContent,
}, },
identifiers::{DeviceId, DeviceKeyAlgorithm, UserId}, DeviceId, DeviceKeyAlgorithm, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::json;
use super::{deserialize_instant, serialize_instant, IdentityKeys}; use super::{deserialize_instant, serialize_instant, IdentityKeys};
use crate::{ use crate::{
@ -105,21 +105,17 @@ impl Session {
/// encrypted, this needs to be the device that was used to create this /// encrypted, this needs to be the device that was used to create this
/// session with. /// session with.
/// ///
/// * `event_type` - The type of the event.
///
/// * `content` - The content of the event. /// * `content` - The content of the event.
pub async fn encrypt( pub async fn encrypt(
&mut self, &mut self,
recipient_device: &ReadOnlyDevice, recipient_device: &ReadOnlyDevice,
event_type: EventType, content: AnyToDeviceEventContent,
content: Value,
) -> OlmResult<EncryptedEventContent> { ) -> OlmResult<EncryptedEventContent> {
let recipient_signing_key = recipient_device let recipient_signing_key = recipient_device
.get_key(DeviceKeyAlgorithm::Ed25519) .get_key(DeviceKeyAlgorithm::Ed25519)
.ok_or(EventError::MissingSigningKey)?; .ok_or(EventError::MissingSigningKey)?;
let relates_to = let event_type = content.event_type();
content.get("m.relates_to").cloned().and_then(|v| serde_json::from_value(v).ok());
let payload = json!({ let payload = json!({
"sender": self.user_id.as_str(), "sender": self.user_id.as_str(),
@ -149,7 +145,7 @@ impl Session {
content, content,
self.our_identity_keys.curve25519().to_string(), self.our_identity_keys.curve25519().to_string(),
)), )),
relates_to, None,
)) ))
} }

View File

@ -25,16 +25,16 @@ use std::{
use matrix_sdk_common::locks::Mutex; use matrix_sdk_common::locks::Mutex;
use pk_signing::{MasterSigning, PickledSignings, SelfSigning, Signing, SigningError, UserSigning}; use pk_signing::{MasterSigning, PickledSignings, SelfSigning, Signing, SigningError, UserSigning};
use ruma::{ use ruma::{
api::client::r0::keys::{upload_signatures::Request as SignatureUploadRequest, KeyUsage}, api::client::r0::keys::upload_signatures::Request as SignatureUploadRequest,
encryption::DeviceKeys, encryption::{DeviceKeys, KeyUsage},
DeviceKeyAlgorithm, DeviceKeyId, UserId, DeviceKeyAlgorithm, DeviceKeyId, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Error as JsonError; use serde_json::Error as JsonError;
use crate::{ use crate::{
error::SignatureError, requests::UploadSigningKeysRequest, OwnUserIdentity, ReadOnlyAccount, error::SignatureError, identities::MasterPubkey, requests::UploadSigningKeysRequest,
ReadOnlyDevice, UserIdentity, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentity,
}; };
/// Private cross signing identity. /// Private cross signing identity.
@ -91,6 +91,21 @@ impl PrivateCrossSigningIdentity {
!(has_master && has_user && has_self) !(has_master && has_user && has_self)
} }
/// Can we sign our own devices, i.e. do we have a self signing key.
pub async fn can_sign_devices(&self) -> bool {
self.self_signing_key.lock().await.is_some()
}
/// Do we have the master key.
pub async fn has_master_key(&self) -> bool {
self.master_key.lock().await.is_some()
}
/// Get the public part of the master key, if we have one.
pub async fn master_public_key(&self) -> Option<MasterPubkey> {
self.master_key.lock().await.as_ref().map(|m| m.public_key.to_owned())
}
/// Create a new empty identity. /// Create a new empty identity.
pub(crate) fn empty(user_id: UserId) -> Self { pub(crate) fn empty(user_id: UserId) -> Self {
Self { Self {
@ -102,7 +117,9 @@ impl PrivateCrossSigningIdentity {
} }
} }
pub(crate) async fn as_public_identity(&self) -> Result<OwnUserIdentity, SignatureError> { pub(crate) async fn to_public_identity(
&self,
) -> Result<ReadOnlyOwnUserIdentity, SignatureError> {
let master = self let master = self
.master_key .master_key
.lock() .lock()
@ -127,7 +144,7 @@ impl PrivateCrossSigningIdentity {
.ok_or(SignatureError::MissingSigningKey)? .ok_or(SignatureError::MissingSigningKey)?
.public_key .public_key
.clone(); .clone();
let identity = OwnUserIdentity::new(master, self_signing, user_signing)?; let identity = ReadOnlyOwnUserIdentity::new(master, self_signing, user_signing)?;
identity.mark_as_verified(); identity.mark_as_verified();
Ok(identity) Ok(identity)
@ -136,7 +153,7 @@ impl PrivateCrossSigningIdentity {
/// Sign the given public user identity with this private identity. /// Sign the given public user identity with this private identity.
pub(crate) async fn sign_user( pub(crate) async fn sign_user(
&self, &self,
user_identity: &UserIdentity, user_identity: &ReadOnlyUserIdentity,
) -> Result<SignatureUploadRequest, SignatureError> { ) -> Result<SignatureUploadRequest, SignatureError> {
let signed_keys = self let signed_keys = self
.user_signing_key .user_signing_key
@ -388,11 +405,11 @@ mod test {
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use ruma::{api::client::r0::keys::CrossSigningKey, user_id, UserId}; use ruma::{encryption::CrossSigningKey, user_id, UserId};
use super::{PrivateCrossSigningIdentity, Signing}; use super::{PrivateCrossSigningIdentity, Signing};
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentity}, identities::{ReadOnlyDevice, ReadOnlyUserIdentity},
olm::ReadOnlyAccount, olm::ReadOnlyAccount,
}; };
@ -503,7 +520,7 @@ mod test {
let bob_account = ReadOnlyAccount::new(&user_id!("@bob:localhost"), "DEVICEID".into()); let bob_account = ReadOnlyAccount::new(&user_id!("@bob:localhost"), "DEVICEID".into());
let (bob_private, _, _) = PrivateCrossSigningIdentity::new_with_account(&bob_account).await; let (bob_private, _, _) = PrivateCrossSigningIdentity::new_with_account(&bob_account).await;
let mut bob_public = UserIdentity::from_private(&bob_private).await; let mut bob_public = ReadOnlyUserIdentity::from_private(&bob_private).await;
let user_signing = identity.user_signing_key.lock().await; let user_signing = identity.user_signing_key.lock().await;
let user_signing = user_signing.as_ref().unwrap(); let user_signing = user_signing.as_ref().unwrap();

View File

@ -24,8 +24,7 @@ use olm_rs::pk::OlmPkSigning;
#[cfg(test)] #[cfg(test)]
use olm_rs::{errors::OlmUtilityError, utility::OlmUtility}; use olm_rs::{errors::OlmUtilityError, utility::OlmUtility};
use ruma::{ use ruma::{
api::client::r0::keys::{CrossSigningKey, KeyUsage}, encryption::{CrossSigningKey, DeviceKeys, KeyUsage},
encryption::DeviceKeys,
serde::CanonicalJsonValue, serde::CanonicalJsonValue,
DeviceKeyAlgorithm, DeviceKeyId, UserId, DeviceKeyAlgorithm, DeviceKeyId, UserId,
}; };
@ -38,7 +37,7 @@ use crate::{
error::SignatureError, error::SignatureError,
identities::{MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, identities::{MasterPubkey, SelfSigningPubkey, UserSigningPubkey},
utilities::{decode_url_safe as decode, encode_url_safe as encode, DecodeError}, utilities::{decode_url_safe as decode, encode_url_safe as encode, DecodeError},
UserIdentity, ReadOnlyUserIdentity,
}; };
const NONCE_SIZE: usize = 12; const NONCE_SIZE: usize = 12;
@ -187,7 +186,7 @@ impl UserSigning {
pub async fn sign_user( pub async fn sign_user(
&self, &self,
user: &UserIdentity, user: &ReadOnlyUserIdentity,
) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> { ) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> {
let user_master: &CrossSigningKey = user.master_key().as_ref(); let user_master: &CrossSigningKey = user.master_key().as_ref();
let signature = self.inner.sign_json(serde_json::to_value(user_master)?).await?; let signature = self.inner.sign_json(serde_json::to_value(user_master)?).await?;

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#![allow(missing_docs)]
use std::{collections::BTreeMap, sync::Arc, time::Duration}; use std::{collections::BTreeMap, sync::Arc, time::Duration};
use matrix_sdk_common::uuid::Uuid; use matrix_sdk_common::uuid::Uuid;
@ -27,16 +25,17 @@ use ruma::{
Request as SignatureUploadRequest, Response as SignatureUploadResponse, Request as SignatureUploadRequest, Response as SignatureUploadResponse,
}, },
upload_signing_keys::Response as SigningKeysUploadResponse, upload_signing_keys::Response as SigningKeysUploadResponse,
CrossSigningKey,
}, },
message::send_message_event::Response as RoomMessageResponse, message::send_message_event::Response as RoomMessageResponse,
to_device::{send_event_to_device::Response as ToDeviceResponse, DeviceIdOrAllDevices}, to_device::send_event_to_device::Response as ToDeviceResponse,
}, },
events::{AnyMessageEventContent, EventType}, encryption::CrossSigningKey,
events::{AnyMessageEventContent, AnyToDeviceEventContent, EventContent, EventType},
serde::Raw,
to_device::DeviceIdOrAllDevices,
DeviceIdBox, RoomId, UserId, DeviceIdBox, RoomId, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::value::RawValue as RawJsonValue;
/// Customized version of /// Customized version of
/// `ruma_client_api::r0::to_device::send_event_to_device::Request`, /// `ruma_client_api::r0::to_device::send_event_to_device::Request`,
@ -56,10 +55,68 @@ pub struct ToDeviceRequest {
/// The content's type for this field will be updated in a future /// The content's type for this field will be updated in a future
/// release, until then you can create a value using /// release, until then you can create a value using
/// `serde_json::value::to_raw_value`. /// `serde_json::value::to_raw_value`.
pub messages: BTreeMap<UserId, BTreeMap<DeviceIdOrAllDevices, Box<RawJsonValue>>>, pub messages: BTreeMap<UserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,
} }
impl ToDeviceRequest { impl ToDeviceRequest {
/// Create a new owned to-device request
///
/// # Arguments
///
/// * `recipient` - The ID of the user that should receive this to-device
/// event.
///
/// * `recipient_device` - The device that should receive this to-device
/// event, or all devices.
///
/// * `content` - The content of the to-device event.
pub(crate) fn new(
recipient: &UserId,
recipient_device: impl Into<DeviceIdOrAllDevices>,
content: AnyToDeviceEventContent,
) -> Self {
Self::new_with_id(recipient, recipient_device, content, Uuid::new_v4())
}
pub(crate) fn new_for_recipients(
recipient: &UserId,
recipient_devices: Vec<DeviceIdBox>,
content: AnyToDeviceEventContent,
txn_id: Uuid,
) -> Self {
let mut messages = BTreeMap::new();
let event_type = EventType::from(content.event_type());
if recipient_devices.is_empty() {
Self::new(recipient, DeviceIdOrAllDevices::AllDevices, content)
} else {
let device_messages = recipient_devices
.into_iter()
.map(|d| (DeviceIdOrAllDevices::DeviceId(d), Raw::from(content.clone())))
.collect();
messages.insert(recipient.clone(), device_messages);
ToDeviceRequest { event_type, txn_id, messages }
}
}
pub(crate) fn new_with_id(
recipient: &UserId,
recipient_device: impl Into<DeviceIdOrAllDevices>,
content: AnyToDeviceEventContent,
txn_id: Uuid,
) -> Self {
let mut messages = BTreeMap::new();
let mut user_messages = BTreeMap::new();
let event_type = EventType::from(content.event_type());
user_messages.insert(recipient_device.into(), Raw::from(content));
messages.insert(recipient.clone(), user_messages);
ToDeviceRequest { event_type, txn_id, messages }
}
/// Gets the transaction ID as a string. /// Gets the transaction ID as a string.
pub fn txn_id_string(&self) -> String { pub fn txn_id_string(&self) -> String {
self.txn_id.to_string() self.txn_id.to_string()
@ -133,6 +190,8 @@ pub enum OutgoingRequests {
/// Signature upload request, this request is used after a successful device /// Signature upload request, this request is used after a successful device
/// or user verification is done. /// or user verification is done.
SignatureUpload(SignatureUploadRequest), SignatureUpload(SignatureUploadRequest),
/// A room message request, usually for sending in-room interactive
/// verification events.
RoomMessage(RoomMessageRequest), RoomMessage(RoomMessageRequest),
} }
@ -205,9 +264,9 @@ pub enum IncomingResponse<'a> {
/// The cross signing keys upload response, marking our private cross /// The cross signing keys upload response, marking our private cross
/// signing identity as shared. /// signing identity as shared.
SigningKeysUpload(&'a SigningKeysUploadResponse), SigningKeysUpload(&'a SigningKeysUploadResponse),
/// The cross signing keys upload response, marking our private cross /// The cross signing signature upload response.
/// signing identity as shared.
SignatureUpload(&'a SignatureUploadResponse), SignatureUpload(&'a SignatureUploadResponse),
/// A room message response, usually for interactive verifications.
RoomMessage(&'a RoomMessageResponse), RoomMessage(&'a RoomMessageResponse),
} }
@ -270,6 +329,7 @@ impl OutgoingRequest {
} }
} }
/// Customized owned request type for sending out room messages.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct RoomMessageRequest { pub struct RoomMessageRequest {
/// The room to send the event to. /// The room to send the event to.
@ -286,13 +346,17 @@ pub struct RoomMessageRequest {
pub content: AnyMessageEventContent, pub content: AnyMessageEventContent,
} }
/// An enum over the different outgoing verification based requests.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum OutgoingVerificationRequest { pub enum OutgoingVerificationRequest {
/// The to-device verification request variant.
ToDevice(ToDeviceRequest), ToDevice(ToDeviceRequest),
/// The in-room verification request variant.
InRoom(RoomMessageRequest), InRoom(RoomMessageRequest),
} }
impl OutgoingVerificationRequest { impl OutgoingVerificationRequest {
/// Get the unique id of this request.
pub fn request_id(&self) -> Uuid { pub fn request_id(&self) -> Uuid {
match self { match self {
OutgoingVerificationRequest::ToDevice(t) => t.txn_id, OutgoingVerificationRequest::ToDevice(t) => t.txn_id,

View File

@ -21,14 +21,14 @@ use dashmap::DashMap;
use futures::future::join_all; use futures::future::join_all;
use matrix_sdk_common::{executor::spawn, uuid::Uuid}; use matrix_sdk_common::{executor::spawn, uuid::Uuid};
use ruma::{ use ruma::{
api::client::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
room::{encrypted::EncryptedEventContent, history_visibility::HistoryVisibility}, room::{encrypted::EncryptedEventContent, history_visibility::HistoryVisibility},
AnyMessageEventContent, EventType, AnyMessageEventContent, AnyToDeviceEventContent, EventType,
}, },
serde::Raw,
to_device::DeviceIdOrAllDevices,
DeviceId, DeviceIdBox, RoomId, UserId, DeviceId, DeviceIdBox, RoomId, UserId,
}; };
use serde_json::Value;
use tracing::{debug, info, trace}; use tracing::{debug, info, trace};
use crate::{ use crate::{
@ -146,11 +146,6 @@ impl GroupSessionManager {
let mut changes = Changes::default(); let mut changes = Changes::default();
changes.outbound_group_sessions.push(s.clone()); changes.outbound_group_sessions.push(s.clone());
self.store.save_changes(changes).await?; self.store.save_changes(changes).await?;
} else {
trace!(
request_id = request_id.to_string().as_str(),
"Marking room key share request as sent but session found that owns the given id"
)
} }
Ok(()) Ok(())
@ -229,22 +224,22 @@ impl GroupSessionManager {
/// Encrypt the given content for the given devices and create a to-device /// Encrypt the given content for the given devices and create a to-device
/// requests that sends the encrypted content to them. /// requests that sends the encrypted content to them.
async fn encrypt_session_for( async fn encrypt_session_for(
content: Value, content: AnyToDeviceEventContent,
devices: Vec<Device>, devices: Vec<Device>,
) -> OlmResult<(Uuid, ToDeviceRequest, Vec<Session>)> { ) -> OlmResult<(Uuid, ToDeviceRequest, Vec<Session>)> {
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
let mut changed_sessions = Vec::new(); let mut changed_sessions = Vec::new();
let encrypt = |device: Device, content: Value| async move { let encrypt = |device: Device, content: AnyToDeviceEventContent| async move {
let mut message = BTreeMap::new(); let mut message = BTreeMap::new();
let encrypted = device.encrypt(EventType::RoomKey, content.clone()).await; let encrypted = device.encrypt(content.clone()).await;
let used_session = match encrypted { let used_session = match encrypted {
Ok((session, encrypted)) => { Ok((session, encrypted)) => {
message.entry(device.user_id().clone()).or_insert_with(BTreeMap::new).insert( message.entry(device.user_id().clone()).or_insert_with(BTreeMap::new).insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()), DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?, Raw::from(AnyToDeviceEventContent::RoomEncrypted(encrypted)),
); );
Some(session) Some(session)
} }
@ -380,7 +375,7 @@ impl GroupSessionManager {
pub async fn encrypt_request( pub async fn encrypt_request(
chunk: Vec<Device>, chunk: Vec<Device>,
content: Value, content: AnyToDeviceEventContent,
outbound: OutboundGroupSession, outbound: OutboundGroupSession,
message_index: u32, message_index: u32,
being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>, being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
@ -417,10 +412,7 @@ impl GroupSessionManager {
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>, encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> { ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
debug!( debug!(room_id = room_id.as_str(), "Checking if a room key needs to be shared",);
room_id = room_id.as_str(),
"Checking if a group session needs to be shared for room {}", room_id
);
let encryption_settings = encryption_settings.into(); let encryption_settings = encryption_settings.into();
let history_visibility = encryption_settings.history_visibility.clone(); let history_visibility = encryption_settings.history_visibility.clone();
@ -471,7 +463,7 @@ impl GroupSessionManager {
.flatten() .flatten()
.collect(); .collect();
let key_content = outbound.as_json().await; let key_content = outbound.as_content().await;
let message_index = outbound.message_index().await; let message_index = outbound.message_index().await;
if !devices.is_empty() { if !devices.is_empty() {

View File

@ -17,15 +17,13 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration};
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use matrix_sdk_common::uuid::Uuid; use matrix_sdk_common::uuid::Uuid;
use ruma::{ use ruma::{
api::client::r0::{ api::client::r0::keys::claim_keys::{
keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, Request as KeysClaimRequest, Response as KeysClaimResponse,
to_device::DeviceIdOrAllDevices,
}, },
assign, assign,
events::EventType, events::{dummy::DummyToDeviceEventContent, AnyToDeviceEventContent},
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, UserId, DeviceId, DeviceIdBox, DeviceKeyAlgorithm, UserId,
}; };
use serde_json::{json, value::to_raw_value};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use crate::{ use crate::{
@ -118,28 +116,21 @@ impl SessionManager {
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> { async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
if self.wedged_devices.get(user_id).map(|d| d.remove(device_id)).flatten().is_some() { if self.wedged_devices.get(user_id).map(|d| d.remove(device_id)).flatten().is_some() {
if let Some(device) = self.store.get_device(user_id, device_id).await? { if let Some(device) = self.store.get_device(user_id, device_id).await? {
let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?; let content = AnyToDeviceEventContent::Dummy(DummyToDeviceEventContent::new());
let id = Uuid::new_v4(); let (_, content) = device.encrypt(content).await?;
let mut messages = BTreeMap::new();
messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert( let request = ToDeviceRequest::new(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()), device.user_id(),
to_raw_value(&content)?, device.device_id().to_owned(),
AnyToDeviceEventContent::RoomEncrypted(content),
); );
let request = OutgoingRequest { let request = OutgoingRequest {
request_id: id, request_id: request.txn_id,
request: Arc::new( request: Arc::new(request.into()),
ToDeviceRequest {
event_type: EventType::RoomEncrypted,
txn_id: id,
messages,
}
.into(),
),
}; };
self.outgoing_to_device_requests.insert(id, request); self.outgoing_to_device_requests.insert(request.request_id, request);
} }
} }

View File

@ -26,7 +26,7 @@ use super::{
Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentities}, identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
key_request::OutgoingKeyRequest, key_request::OutgoingKeyRequest,
olm::{OutboundGroupSession, PrivateCrossSigningIdentity}, olm::{OutboundGroupSession, PrivateCrossSigningIdentity},
}; };
@ -44,7 +44,7 @@ pub struct MemoryStore {
users_for_key_query: Arc<DashSet<UserId>>, users_for_key_query: Arc<DashSet<UserId>>,
olm_hashes: Arc<DashMap<String, DashSet<String>>>, olm_hashes: Arc<DashMap<String, DashSet<String>>>,
devices: DeviceStore, devices: DeviceStore,
identities: Arc<DashMap<UserId, UserIdentities>>, identities: Arc<DashMap<UserId, ReadOnlyUserIdentities>>,
outgoing_key_requests: Arc<DashMap<Uuid, OutgoingKeyRequest>>, outgoing_key_requests: Arc<DashMap<Uuid, OutgoingKeyRequest>>,
key_requests_by_info: Arc<DashMap<String, Uuid>>, key_requests_by_info: Arc<DashMap<String, Uuid>>,
} }
@ -215,7 +215,7 @@ impl CryptoStore for MemoryStore {
Ok(self.devices.user_devices(user_id)) Ok(self.devices.user_devices(user_id))
} }
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> { async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<ReadOnlyUserIdentities>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self.identities.get(user_id).map(|i| i.clone())) Ok(self.identities.get(user_id).map(|i| i.clone()))
} }

View File

@ -56,11 +56,8 @@ pub use memorystore::MemoryStore;
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
pub use pickle_key::{EncryptedPickleKey, PickleKey}; pub use pickle_key::{EncryptedPickleKey, PickleKey};
use ruma::{ use ruma::{
events::room_key_request::RequestedKeyInfo, events::room_key_request::RequestedKeyInfo, identifiers::Error as IdentifierValidationError,
identifiers::{ DeviceId, DeviceIdBox, DeviceKeyAlgorithm, RoomId, UserId,
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId,
UserId,
},
}; };
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use thiserror::Error; use thiserror::Error;
@ -69,7 +66,10 @@ use thiserror::Error;
pub use self::sled::SledStore; pub use self::sled::SledStore;
use crate::{ use crate::{
error::SessionUnpicklingError, error::SessionUnpicklingError,
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, identities::{
user::{OwnUserIdentity, UserIdentities, UserIdentity},
Device, ReadOnlyDevice, ReadOnlyUserIdentities, UserDevices,
},
key_request::OutgoingKeyRequest, key_request::OutgoingKeyRequest,
olm::{ olm::{
InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
@ -112,8 +112,8 @@ pub struct Changes {
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
#[allow(missing_docs)] #[allow(missing_docs)]
pub struct IdentityChanges { pub struct IdentityChanges {
pub new: Vec<UserIdentities>, pub new: Vec<ReadOnlyUserIdentities>,
pub changed: Vec<UserIdentities>, pub changed: Vec<ReadOnlyUserIdentities>,
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@ -218,8 +218,8 @@ impl Store {
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<Option<Device>> { ) -> Result<Option<Device>> {
let own_identity = let own_identity =
self.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten(); self.inner.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
let device_owner_identity = self.get_user_identity(user_id).await?; let device_owner_identity = self.inner.get_user_identity(user_id).await?;
Ok(self.inner.get_device(user_id, device_id).await?.map(|d| Device { Ok(self.inner.get_device(user_id, device_id).await?.map(|d| Device {
inner: d, inner: d,
@ -229,6 +229,20 @@ impl Store {
device_owner_identity, device_owner_identity,
})) }))
} }
pub async fn get_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
Ok(self.inner.get_user_identity(user_id).await?.map(|i| match i {
ReadOnlyUserIdentities::Own(i) => OwnUserIdentity {
inner: i,
verification_machine: self.verification_machine.clone(),
}
.into(),
ReadOnlyUserIdentities::Other(i) => {
UserIdentity { inner: i, verification_machine: self.verification_machine.clone() }
.into()
}
}))
}
} }
impl Deref for Store { impl Deref for Store {
@ -239,7 +253,7 @@ impl Deref for Store {
} }
} }
#[derive(Error, Debug)] #[derive(Debug, Error)]
/// The crypto store's error type. /// The crypto store's error type.
pub enum CryptoStoreError { pub enum CryptoStoreError {
/// The account that owns the sessions, group sessions, and devices wasn't /// The account that owns the sessions, group sessions, and devices wasn't
@ -391,7 +405,7 @@ pub trait CryptoStore: AsyncTraitDeps {
/// # Arguments /// # Arguments
/// ///
/// * `user_id` - The user for which we should get the identity. /// * `user_id` - The user for which we should get the identity.
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>>; async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<ReadOnlyUserIdentities>>;
/// Check if a hash for an Olm message stored in the database. /// Check if a hash for an Olm message stored in the database.
async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>; async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>;

View File

@ -35,7 +35,7 @@ use super::{
ReadOnlyAccount, Result, Session, ReadOnlyAccount, Result, Session,
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentities}, identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
key_request::OutgoingKeyRequest, key_request::OutgoingKeyRequest,
olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity}, olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity},
}; };
@ -669,7 +669,7 @@ impl CryptoStore for SledStore {
.collect() .collect()
} }
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> { async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<ReadOnlyUserIdentities>> {
Ok(self Ok(self
.identities .identities
.get(user_id.encode())? .get(user_id.encode())?
@ -757,9 +757,8 @@ mod test {
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use olm_rs::outbound_group_session::OlmOutboundGroupSession; use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use ruma::{ use ruma::{
api::client::r0::keys::SignedKey, encryption::SignedKey, events::room_key_request::RequestedKeyInfo, room_id, user_id,
events::room_key_request::RequestedKeyInfo, DeviceId, EventEncryptionAlgorithm, UserId,
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
}; };
use tempfile::tempdir; use tempfile::tempdir;

View File

@ -17,13 +17,17 @@ use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use matrix_sdk_common::uuid::Uuid; use matrix_sdk_common::uuid::Uuid;
use ruma::{DeviceId, UserId}; use ruma::{DeviceId, UserId};
use tracing::trace;
use super::{event_enums::OutgoingContent, sas::content_to_request, Sas, Verification}; use super::{event_enums::OutgoingContent, Sas, Verification};
use crate::{OutgoingRequest, RoomMessageRequest}; use crate::{
OutgoingRequest, OutgoingVerificationRequest, QrVerification, RoomMessageRequest,
ToDeviceRequest,
};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct VerificationCache { pub struct VerificationCache {
verification: Arc<DashMap<String, Verification>>, verification: Arc<DashMap<UserId, DashMap<String, Verification>>>,
outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>, outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
} }
@ -35,41 +39,75 @@ impl VerificationCache {
#[cfg(test)] #[cfg(test)]
#[allow(dead_code)] #[allow(dead_code)]
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.verification.is_empty() self.verification.iter().all(|m| m.is_empty())
}
pub fn insert(&self, verification: impl Into<Verification>) {
let verification = verification.into();
self.verification
.entry(verification.other_user().to_owned())
.or_insert_with(DashMap::new)
.insert(verification.flow_id().to_owned(), verification);
} }
pub fn insert_sas(&self, sas: Sas) { pub fn insert_sas(&self, sas: Sas) {
self.verification.insert(sas.flow_id().as_str().to_string(), sas.into()); self.insert(sas);
}
pub fn insert_qr(&self, qr: QrVerification) {
self.insert(qr)
}
pub fn get_qr(&self, sender: &UserId, flow_id: &str) -> Option<QrVerification> {
self.get(sender, flow_id).and_then(|v| {
if let Verification::QrV1(qr) = v {
Some(qr)
} else {
None
}
})
}
pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
self.verification.get(sender).and_then(|m| m.get(flow_id).map(|v| v.clone()))
} }
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> { pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
self.outgoing_requests.iter().map(|r| (*r).clone()).collect() self.outgoing_requests.iter().map(|r| (*r).clone()).collect()
} }
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> { pub fn garbage_collect(&self) -> Vec<OutgoingVerificationRequest> {
self.verification.retain(|_, s| !(s.is_done() || s.is_cancelled())); for user_verification in self.verification.iter() {
user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
}
self.verification.retain(|_, m| !m.is_empty());
self.verification self.verification
.iter() .iter()
.filter_map(|s| { .flat_map(|v| {
#[allow(irrefutable_let_patterns)] let requests: Vec<OutgoingVerificationRequest> = v
if let Verification::SasV1(s) = s.value() { .value()
s.cancel_if_timed_out().map(|r| OutgoingRequest { .iter()
request_id: r.request_id(), .filter_map(|s| {
request: Arc::new(r.into()), if let Verification::SasV1(s) = s.value() {
s.cancel_if_timed_out()
} else {
None
}
}) })
} else { .collect();
None
} requests
}) })
.collect() .collect()
} }
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> { pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.verification.get(transaction_id).and_then(|v| { self.get(user_id, flow_id).and_then(|v| {
#[allow(irrefutable_let_patterns)] if let Verification::SasV1(sas) = v {
if let Verification::SasV1(sas) = v.value() { Some(sas)
Some(sas.clone())
} else { } else {
None None
} }
@ -77,9 +115,16 @@ impl VerificationCache {
} }
pub fn add_request(&self, request: OutgoingRequest) { pub fn add_request(&self, request: OutgoingRequest) {
trace!("Adding an outgoing verification request {:?}", request);
self.outgoing_requests.insert(request.request_id, request); self.outgoing_requests.insert(request.request_id, request);
} }
pub fn add_verification_request(&self, request: OutgoingVerificationRequest) {
let request =
OutgoingRequest { request_id: request.request_id(), request: Arc::new(request.into()) };
self.add_request(request);
}
pub fn queue_up_content( pub fn queue_up_content(
&self, &self,
recipient: &UserId, recipient: &UserId,
@ -88,7 +133,7 @@ impl VerificationCache {
) { ) {
match content { match content {
OutgoingContent::ToDevice(c) => { OutgoingContent::ToDevice(c) => {
let request = content_to_request(recipient, recipient_device.to_owned(), c); let request = ToDeviceRequest::new(recipient, recipient_device.to_owned(), c);
let request_id = request.txn_id; let request_id = request.txn_id;
let request = OutgoingRequest { request_id, request: Arc::new(request.into()) }; let request = OutgoingRequest { request_id, request: Arc::new(request.into()) };

View File

@ -33,8 +33,8 @@ use ruma::{
room::message::{KeyVerificationRequestEventContent, MessageType}, room::message::{KeyVerificationRequestEventContent, MessageType},
AnyMessageEvent, AnyMessageEventContent, AnyToDeviceEvent, AnyToDeviceEventContent, AnyMessageEvent, AnyMessageEventContent, AnyToDeviceEvent, AnyToDeviceEventContent,
}, },
identifiers::{DeviceId, RoomId, UserId},
serde::CanonicalJsonValue, serde::CanonicalJsonValue,
DeviceId, MilliSecondsSinceUnixEpoch, RoomId, UserId,
}; };
use super::FlowId; use super::FlowId;
@ -53,6 +53,20 @@ impl AnyEvent<'_> {
} }
} }
pub fn timestamp(&self) -> Option<&MilliSecondsSinceUnixEpoch> {
match self {
AnyEvent::Room(e) => Some(e.origin_server_ts()),
AnyEvent::ToDevice(e) => match e {
AnyToDeviceEvent::KeyVerificationRequest(e) => Some(&e.content.timestamp),
_ => None,
},
}
}
pub fn is_room_event(&self) -> bool {
matches!(self, AnyEvent::Room(_))
}
pub fn verification_content(&self) -> Option<AnyVerificationContent> { pub fn verification_content(&self) -> Option<AnyVerificationContent> {
match self { match self {
AnyEvent::Room(e) => match e { AnyEvent::Room(e) => match e {
@ -64,8 +78,7 @@ impl AnyEvent<'_> {
| AnyMessageEvent::RoomEncrypted(_) | AnyMessageEvent::RoomEncrypted(_)
| AnyMessageEvent::RoomMessageFeedback(_) | AnyMessageEvent::RoomMessageFeedback(_)
| AnyMessageEvent::RoomRedaction(_) | AnyMessageEvent::RoomRedaction(_)
| AnyMessageEvent::Sticker(_) | AnyMessageEvent::Sticker(_) => None,
| AnyMessageEvent::Custom(_) => None,
AnyMessageEvent::RoomMessage(m) => { AnyMessageEvent::RoomMessage(m) => {
if let MessageType::VerificationRequest(v) = &m.content.msgtype { if let MessageType::VerificationRequest(v) = &m.content.msgtype {
Some(RequestContent::from(v).into()) Some(RequestContent::from(v).into())
@ -90,14 +103,14 @@ impl AnyEvent<'_> {
AnyMessageEvent::KeyVerificationDone(e) => { AnyMessageEvent::KeyVerificationDone(e) => {
Some(DoneContent::from(&e.content).into()) Some(DoneContent::from(&e.content).into())
} }
_ => None,
}, },
AnyEvent::ToDevice(e) => match e { AnyEvent::ToDevice(e) => match e {
AnyToDeviceEvent::Dummy(_) AnyToDeviceEvent::Dummy(_)
| AnyToDeviceEvent::RoomKey(_) | AnyToDeviceEvent::RoomKey(_)
| AnyToDeviceEvent::RoomKeyRequest(_) | AnyToDeviceEvent::RoomKeyRequest(_)
| AnyToDeviceEvent::ForwardedRoomKey(_) | AnyToDeviceEvent::ForwardedRoomKey(_)
| AnyToDeviceEvent::RoomEncrypted(_) | AnyToDeviceEvent::RoomEncrypted(_) => None,
| AnyToDeviceEvent::Custom(_) => None,
AnyToDeviceEvent::KeyVerificationRequest(e) => { AnyToDeviceEvent::KeyVerificationRequest(e) => {
Some(RequestContent::from(&e.content).into()) Some(RequestContent::from(&e.content).into())
} }
@ -122,6 +135,7 @@ impl AnyEvent<'_> {
AnyToDeviceEvent::KeyVerificationDone(e) => { AnyToDeviceEvent::KeyVerificationDone(e) => {
Some(DoneContent::from(&e.content).into()) Some(DoneContent::from(&e.content).into())
} }
_ => None,
}, },
} }
} }
@ -163,30 +177,30 @@ impl TryFrom<&AnyMessageEvent> for FlowId {
| AnyMessageEvent::RoomEncrypted(_) | AnyMessageEvent::RoomEncrypted(_)
| AnyMessageEvent::RoomMessageFeedback(_) | AnyMessageEvent::RoomMessageFeedback(_)
| AnyMessageEvent::RoomRedaction(_) | AnyMessageEvent::RoomRedaction(_)
| AnyMessageEvent::Sticker(_) | AnyMessageEvent::Sticker(_) => Err(()),
| AnyMessageEvent::Custom(_) => Err(()),
AnyMessageEvent::KeyVerificationReady(e) => { AnyMessageEvent::KeyVerificationReady(e) => {
Ok(FlowId::from((&e.room_id, &e.content.relation.event_id))) Ok(FlowId::from((&e.room_id, &e.content.relates_to.event_id)))
} }
AnyMessageEvent::RoomMessage(e) => Ok(FlowId::from((&e.room_id, &e.event_id))), AnyMessageEvent::RoomMessage(e) => Ok(FlowId::from((&e.room_id, &e.event_id))),
AnyMessageEvent::KeyVerificationStart(e) => { AnyMessageEvent::KeyVerificationStart(e) => {
Ok(FlowId::from((&e.room_id, &e.content.relation.event_id))) Ok(FlowId::from((&e.room_id, &e.content.relates_to.event_id)))
} }
AnyMessageEvent::KeyVerificationCancel(e) => { AnyMessageEvent::KeyVerificationCancel(e) => {
Ok(FlowId::from((&e.room_id, &e.content.relation.event_id))) Ok(FlowId::from((&e.room_id, &e.content.relates_to.event_id)))
} }
AnyMessageEvent::KeyVerificationAccept(e) => { AnyMessageEvent::KeyVerificationAccept(e) => {
Ok(FlowId::from((&e.room_id, &e.content.relation.event_id))) Ok(FlowId::from((&e.room_id, &e.content.relates_to.event_id)))
} }
AnyMessageEvent::KeyVerificationKey(e) => { AnyMessageEvent::KeyVerificationKey(e) => {
Ok(FlowId::from((&e.room_id, &e.content.relation.event_id))) Ok(FlowId::from((&e.room_id, &e.content.relates_to.event_id)))
} }
AnyMessageEvent::KeyVerificationMac(e) => { AnyMessageEvent::KeyVerificationMac(e) => {
Ok(FlowId::from((&e.room_id, &e.content.relation.event_id))) Ok(FlowId::from((&e.room_id, &e.content.relates_to.event_id)))
} }
AnyMessageEvent::KeyVerificationDone(e) => { AnyMessageEvent::KeyVerificationDone(e) => {
Ok(FlowId::from((&e.room_id, &e.content.relation.event_id))) Ok(FlowId::from((&e.room_id, &e.content.relates_to.event_id)))
} }
_ => Err(()),
} }
} }
} }
@ -200,8 +214,7 @@ impl TryFrom<&AnyToDeviceEvent> for FlowId {
| AnyToDeviceEvent::RoomKey(_) | AnyToDeviceEvent::RoomKey(_)
| AnyToDeviceEvent::RoomKeyRequest(_) | AnyToDeviceEvent::RoomKeyRequest(_)
| AnyToDeviceEvent::ForwardedRoomKey(_) | AnyToDeviceEvent::ForwardedRoomKey(_)
| AnyToDeviceEvent::RoomEncrypted(_) | AnyToDeviceEvent::RoomEncrypted(_) => Err(()),
| AnyToDeviceEvent::Custom(_) => Err(()),
AnyToDeviceEvent::KeyVerificationRequest(e) => { AnyToDeviceEvent::KeyVerificationRequest(e) => {
Ok(FlowId::from(e.content.transaction_id.to_owned())) Ok(FlowId::from(e.content.transaction_id.to_owned()))
} }
@ -226,6 +239,7 @@ impl TryFrom<&AnyToDeviceEvent> for FlowId {
AnyToDeviceEvent::KeyVerificationDone(e) => { AnyToDeviceEvent::KeyVerificationDone(e) => {
Ok(FlowId::from(e.content.transaction_id.to_owned())) Ok(FlowId::from(e.content.transaction_id.to_owned()))
} }
_ => Err(()),
} }
} }
} }
@ -365,6 +379,33 @@ try_from_outgoing_content!(MacContent, KeyVerificationMac);
try_from_outgoing_content!(CancelContent, KeyVerificationCancel); try_from_outgoing_content!(CancelContent, KeyVerificationCancel);
try_from_outgoing_content!(DoneContent, KeyVerificationDone); try_from_outgoing_content!(DoneContent, KeyVerificationDone);
impl<'a> TryFrom<&'a OutgoingContent> for RequestContent<'a> {
type Error = ();
fn try_from(value: &'a OutgoingContent) -> Result<Self, Self::Error> {
match value {
OutgoingContent::Room(_, c) => {
if let AnyMessageEventContent::RoomMessage(m) = c {
if let MessageType::VerificationRequest(c) = &m.msgtype {
Ok(Self::Room(c))
} else {
Err(())
}
} else {
Err(())
}
}
OutgoingContent::ToDevice(c) => {
if let AnyToDeviceEventContent::KeyVerificationRequest(c) = c {
Ok(Self::ToDevice(c))
} else {
Err(())
}
}
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub enum StartContent<'a> { pub enum StartContent<'a> {
ToDevice(&'a StartToDeviceEventContent), ToDevice(&'a StartToDeviceEventContent),
@ -382,7 +423,7 @@ impl<'a> StartContent<'a> {
pub fn flow_id(&self) -> &str { pub fn flow_id(&self) -> &str {
match self { match self {
Self::ToDevice(c) => &c.transaction_id, Self::ToDevice(c) => &c.transaction_id,
Self::Room(c) => c.relation.event_id.as_str(), Self::Room(c) => c.relates_to.event_id.as_str(),
} }
} }
@ -434,7 +475,7 @@ impl<'a> DoneContent<'a> {
pub fn flow_id(&self) -> &str { pub fn flow_id(&self) -> &str {
match self { match self {
Self::ToDevice(c) => &c.transaction_id, Self::ToDevice(c) => &c.transaction_id,
Self::Room(c) => c.relation.event_id.as_str(), Self::Room(c) => c.relates_to.event_id.as_str(),
} }
} }
} }
@ -449,7 +490,7 @@ impl AcceptContent<'_> {
pub fn flow_id(&self) -> &str { pub fn flow_id(&self) -> &str {
match self { match self {
Self::ToDevice(c) => &c.transaction_id, Self::ToDevice(c) => &c.transaction_id,
Self::Room(c) => c.relation.event_id.as_str(), Self::Room(c) => c.relates_to.event_id.as_str(),
} }
} }
@ -480,7 +521,7 @@ impl KeyContent<'_> {
pub fn flow_id(&self) -> &str { pub fn flow_id(&self) -> &str {
match self { match self {
Self::ToDevice(c) => &c.transaction_id, Self::ToDevice(c) => &c.transaction_id,
Self::Room(c) => c.relation.event_id.as_str(), Self::Room(c) => c.relates_to.event_id.as_str(),
} }
} }
@ -502,7 +543,7 @@ impl MacContent<'_> {
pub fn flow_id(&self) -> &str { pub fn flow_id(&self) -> &str {
match self { match self {
Self::ToDevice(c) => &c.transaction_id, Self::ToDevice(c) => &c.transaction_id,
Self::Room(c) => c.relation.event_id.as_str(), Self::Room(c) => c.relates_to.event_id.as_str(),
} }
} }
@ -567,7 +608,7 @@ impl OwnedStartContent {
pub fn flow_id(&self) -> FlowId { pub fn flow_id(&self) -> FlowId {
match self { match self {
Self::ToDevice(c) => FlowId::ToDevice(c.transaction_id.clone()), Self::ToDevice(c) => FlowId::ToDevice(c.transaction_id.clone()),
Self::Room(r, c) => FlowId::InRoom(r.clone(), c.relation.event_id.clone()), Self::Room(r, c) => FlowId::InRoom(r.clone(), c.relates_to.event_id.clone()),
} }
} }
@ -651,94 +692,84 @@ impl From<(RoomId, AnyMessageEventContent)> for OutgoingContent {
} }
} }
#[cfg(test)]
use crate::{OutgoingRequest, OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest}; use crate::{OutgoingRequest, OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest};
#[cfg(test)] impl TryFrom<OutgoingVerificationRequest> for OutgoingContent {
impl From<OutgoingVerificationRequest> for OutgoingContent { type Error = String;
fn from(request: OutgoingVerificationRequest) -> Self {
fn try_from(request: OutgoingVerificationRequest) -> Result<Self, Self::Error> {
match request { match request {
OutgoingVerificationRequest::ToDevice(r) => Self::try_from(r).unwrap(), OutgoingVerificationRequest::ToDevice(r) => Self::try_from(r),
OutgoingVerificationRequest::InRoom(r) => Self::from(r), OutgoingVerificationRequest::InRoom(r) => Ok(Self::from(r)),
} }
} }
} }
#[cfg(test)]
impl From<RoomMessageRequest> for OutgoingContent { impl From<RoomMessageRequest> for OutgoingContent {
fn from(value: RoomMessageRequest) -> Self { fn from(value: RoomMessageRequest) -> Self {
(value.room_id, value.content).into() (value.room_id, value.content).into()
} }
} }
#[cfg(test)]
impl TryFrom<ToDeviceRequest> for OutgoingContent { impl TryFrom<ToDeviceRequest> for OutgoingContent {
type Error = (); type Error = String;
fn try_from(value: ToDeviceRequest) -> Result<Self, Self::Error> { fn try_from(request: ToDeviceRequest) -> Result<Self, Self::Error> {
use ruma::events::EventType; use ruma::events::EventType;
use serde_json::Value; use serde_json::Value;
let json: Value = serde_json::from_str( let json: Value = serde_json::from_str(
value request
.messages .messages
.values() .values()
.next() .next()
.and_then(|m| m.values().next().map(|j| j.get())) .and_then(|m| m.values().next())
.ok_or(())?, .map(|c| c.json().get())
.ok_or_else(|| "Content is missing from the request".to_owned())?,
) )
.map_err(|_| ())?; .map_err(|e| e.to_string())?;
match value.event_type { let content = match request.event_type {
EventType::KeyVerificationRequest => { EventType::KeyVerificationStart => AnyToDeviceEventContent::KeyVerificationStart(
Ok(AnyToDeviceEventContent::KeyVerificationRequest( serde_json::from_value(json).map_err(|e| e.to_string())?,
serde_json::from_value(json).map_err(|_| ())?, ),
) EventType::KeyVerificationKey => AnyToDeviceEventContent::KeyVerificationKey(
.into()) serde_json::from_value(json).map_err(|e| e.to_string())?,
} ),
EventType::KeyVerificationReady => Ok(AnyToDeviceEventContent::KeyVerificationReady( EventType::KeyVerificationAccept => AnyToDeviceEventContent::KeyVerificationAccept(
serde_json::from_value(json).map_err(|_| ())?, serde_json::from_value(json).map_err(|e| e.to_string())?,
) ),
.into()), EventType::KeyVerificationMac => AnyToDeviceEventContent::KeyVerificationMac(
EventType::KeyVerificationDone => Ok(AnyToDeviceEventContent::KeyVerificationDone( serde_json::from_value(json).map_err(|e| e.to_string())?,
serde_json::from_value(json).map_err(|_| ())?, ),
) EventType::KeyVerificationCancel => AnyToDeviceEventContent::KeyVerificationCancel(
.into()), serde_json::from_value(json).map_err(|e| e.to_string())?,
EventType::KeyVerificationStart => Ok(AnyToDeviceEventContent::KeyVerificationStart( ),
serde_json::from_value(json).map_err(|_| ())?, EventType::KeyVerificationReady => AnyToDeviceEventContent::KeyVerificationReady(
) serde_json::from_value(json).map_err(|e| e.to_string())?,
.into()), ),
EventType::KeyVerificationKey => Ok(AnyToDeviceEventContent::KeyVerificationKey( EventType::KeyVerificationDone => AnyToDeviceEventContent::KeyVerificationDone(
serde_json::from_value(json).map_err(|_| ())?, serde_json::from_value(json).map_err(|e| e.to_string())?,
) ),
.into()), EventType::KeyVerificationRequest => AnyToDeviceEventContent::KeyVerificationRequest(
EventType::KeyVerificationAccept => Ok(AnyToDeviceEventContent::KeyVerificationAccept( serde_json::from_value(json).map_err(|e| e.to_string())?,
serde_json::from_value(json).map_err(|_| ())?, ),
) e => return Err(format!("Unsupported event type {}", e)),
.into()), };
EventType::KeyVerificationMac => Ok(AnyToDeviceEventContent::KeyVerificationMac(
serde_json::from_value(json).map_err(|_| ())?, Ok(content.into())
)
.into()),
EventType::KeyVerificationCancel => Ok(AnyToDeviceEventContent::KeyVerificationCancel(
serde_json::from_value(json).map_err(|_| ())?,
)
.into()),
_ => Err(()),
}
} }
} }
#[cfg(test)]
impl TryFrom<OutgoingRequest> for OutgoingContent { impl TryFrom<OutgoingRequest> for OutgoingContent {
type Error = (); type Error = String;
fn try_from(value: OutgoingRequest) -> Result<Self, ()> { fn try_from(value: OutgoingRequest) -> Result<Self, Self::Error> {
match value.request() { match value.request() {
crate::OutgoingRequests::KeysUpload(_) => Err(()), crate::OutgoingRequests::KeysUpload(_) => Err("Invalid request type".to_owned()),
crate::OutgoingRequests::KeysQuery(_) => Err(()), crate::OutgoingRequests::KeysQuery(_) => Err("Invalid request type".to_owned()),
crate::OutgoingRequests::ToDeviceRequest(r) => Self::try_from(r.clone()), crate::OutgoingRequests::ToDeviceRequest(r) => Self::try_from(r.clone()),
crate::OutgoingRequests::SignatureUpload(_) => Err(()), crate::OutgoingRequests::SignatureUpload(_) => Err("Invalid request type".to_owned()),
crate::OutgoingRequests::RoomMessage(r) => Ok(Self::from(r.clone())), crate::OutgoingRequests::RoomMessage(r) => Ok(Self::from(r.clone())),
} }
} }

View File

@ -12,25 +12,36 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{convert::TryFrom, sync::Arc}; use std::{
convert::{TryFrom, TryInto},
sync::Arc,
};
use dashmap::DashMap; use dashmap::DashMap;
use matrix_sdk_common::{locks::Mutex, uuid::Uuid}; use matrix_sdk_common::{locks::Mutex, uuid::Uuid};
use ruma::{DeviceId, UserId}; use ruma::{
use tracing::{info, warn}; events::{
key::verification::VerificationMethod, AnyToDeviceEvent, AnyToDeviceEventContent,
ToDeviceEvent,
},
serde::Raw,
DeviceId, DeviceIdBox, EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId,
};
use tracing::{info, trace, warn};
use super::{ use super::{
cache::VerificationCache, cache::VerificationCache,
event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent}, event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent},
requests::VerificationRequest, requests::VerificationRequest,
sas::{content_to_request, Sas}, sas::Sas,
FlowId, VerificationResult, FlowId, Verification, VerificationResult,
}; };
use crate::{ use crate::{
olm::PrivateCrossSigningIdentity, olm::PrivateCrossSigningIdentity,
requests::OutgoingRequest, requests::OutgoingRequest,
store::{CryptoStore, CryptoStoreError}, store::{CryptoStore, CryptoStoreError},
OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, RoomMessageRequest, OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyUserIdentity,
RoomMessageRequest, ToDeviceRequest,
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -39,7 +50,7 @@ pub struct VerificationMachine {
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>, private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
pub(crate) store: Arc<dyn CryptoStore>, pub(crate) store: Arc<dyn CryptoStore>,
verifications: VerificationCache, verifications: VerificationCache,
requests: Arc<DashMap<String, VerificationRequest>>, requests: Arc<DashMap<UserId, DashMap<String, VerificationRequest>>>,
} }
impl VerificationMachine { impl VerificationMachine {
@ -57,6 +68,65 @@ impl VerificationMachine {
} }
} }
pub(crate) fn own_user_id(&self) -> &UserId {
self.account.user_id()
}
pub(crate) fn own_device_id(&self) -> &DeviceId {
self.account.device_id()
}
pub(crate) async fn request_to_device_verification(
&self,
user_id: &UserId,
recipient_devices: Vec<DeviceIdBox>,
methods: Option<Vec<VerificationMethod>>,
) -> (VerificationRequest, OutgoingVerificationRequest) {
let flow_id = FlowId::from(Uuid::new_v4().to_string());
let verification = VerificationRequest::new(
self.verifications.clone(),
self.account.clone(),
self.private_identity.lock().await.clone(),
self.store.clone(),
flow_id,
user_id,
recipient_devices,
methods,
);
self.insert_request(verification.clone());
let request = verification.request_to_device();
(verification, request.into())
}
pub async fn request_verification(
&self,
identity: &ReadOnlyUserIdentity,
room_id: &RoomId,
request_event_id: &EventId,
methods: Option<Vec<VerificationMethod>>,
) -> VerificationRequest {
let flow_id = FlowId::InRoom(room_id.to_owned(), request_event_id.to_owned());
let request = VerificationRequest::new(
self.verifications.clone(),
self.account.clone(),
self.private_identity.lock().await.clone(),
self.store.clone(),
flow_id,
identity.user_id(),
vec![],
methods,
);
self.insert_request(request.clone());
request
}
pub async fn start_sas( pub async fn start_sas(
&self, &self,
device: ReadOnlyDevice, device: ReadOnlyDevice,
@ -71,6 +141,8 @@ impl VerificationMachine {
self.store.clone(), self.store.clone(),
identity, identity,
None, None,
true,
None,
); );
let request = match content { let request = match content {
@ -79,7 +151,7 @@ impl VerificationMachine {
} }
OutgoingContent::ToDevice(c) => { OutgoingContent::ToDevice(c) => {
let request = let request =
content_to_request(device.user_id(), device.device_id().to_owned(), c); ToDeviceRequest::new(device.user_id(), device.device_id().to_owned(), c);
self.verifications.insert_sas(sas.clone()); self.verifications.insert_sas(sas.clone());
@ -90,12 +162,60 @@ impl VerificationMachine {
Ok((sas, request)) Ok((sas, request))
} }
pub fn get_request(&self, flow_id: impl AsRef<str>) -> Option<VerificationRequest> { pub fn get_request(
self.requests.get(flow_id.as_ref()).map(|s| s.clone()) &self,
user_id: &UserId,
flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> {
self.requests.get(user_id).and_then(|v| v.get(flow_id.as_ref()).map(|s| s.clone()))
} }
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> { pub fn get_requests(&self, user_id: &UserId) -> Vec<VerificationRequest> {
self.verifications.get_sas(transaction_id) self.requests
.get(user_id)
.map(|v| v.iter().map(|i| i.value().clone()).collect())
.unwrap_or_default()
}
fn insert_request(&self, request: VerificationRequest) {
self.requests
.entry(request.other_user().to_owned())
.or_insert_with(DashMap::new)
.insert(request.flow_id().as_str().to_owned(), request);
}
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
self.verifications.get(user_id, flow_id)
}
pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.verifications.get_sas(user_id, flow_id)
}
#[cfg(not(target_arch = "wasm32"))]
fn is_timestamp_valid(timestamp: &MilliSecondsSinceUnixEpoch) -> bool {
use ruma::{uint, UInt};
// The event should be ignored if the event is older than 10 minutes
let old_timestamp_threshold: UInt = uint!(600);
// The event should be ignored if the event is 5 minutes or more into the
// future.
let timestamp_threshold: UInt = uint!(300);
let timestamp = timestamp.as_secs();
let now = MilliSecondsSinceUnixEpoch::now().as_secs();
!(now.saturating_sub(timestamp) > old_timestamp_threshold
|| timestamp.saturating_sub(now) > timestamp_threshold)
}
#[cfg(target_arch = "wasm32")]
fn is_timestamp_valid(timestamp: &MilliSecondsSinceUnixEpoch) -> bool {
// TODO the non-wasm method with the same name uses
// `MilliSecondsSinceUnixEpoch::now()` which internally uses
// `SystemTime::now()` this panics under WASM, thus we're returning here
// true for now.
true
} }
fn queue_up_content( fn queue_up_content(
@ -115,12 +235,40 @@ impl VerificationMachine {
self.verifications.outgoing_requests() self.verifications.outgoing_requests()
} }
pub fn garbage_collect(&self) { pub fn garbage_collect(&self) -> Vec<Raw<AnyToDeviceEvent>> {
self.requests.retain(|_, r| !(r.is_done() || r.is_cancelled())); let mut events = vec![];
for request in self.verifications.garbage_collect() { for user_verification in self.requests.iter() {
self.verifications.add_request(request) user_verification.retain(|_, v| !(v.is_done() || v.is_cancelled()));
} }
self.requests.retain(|_, v| !v.is_empty());
let mut requests: Vec<OutgoingVerificationRequest> = self
.requests
.iter()
.flat_map(|v| {
let requests: Vec<OutgoingVerificationRequest> =
v.value().iter().filter_map(|v| v.cancel_if_timed_out()).collect();
requests
})
.collect();
requests.extend(self.verifications.garbage_collect().into_iter());
for request in requests {
if let Ok(OutgoingContent::ToDevice(AnyToDeviceEventContent::KeyVerificationCancel(
content,
))) = request.clone().try_into()
{
let event = ToDeviceEvent { content, sender: self.account.user_id().to_owned() };
events.push(AnyToDeviceEvent::KeyVerificationCancel(event).into());
}
self.verifications.add_verification_request(request)
}
events
} }
async fn mark_sas_as_done( async fn mark_sas_as_done(
@ -175,6 +323,14 @@ impl VerificationMachine {
); );
}; };
let event_sent_from_us = |event: &AnyEvent<'_>, from_device: &DeviceId| {
if event.sender() == self.account.user_id() {
from_device == self.account.device_id() || event.is_room_event()
} else {
false
}
};
if let Some(content) = event.verification_content() { if let Some(content) = event.verification_content() {
match &content { match &content {
AnyVerificationContent::Request(r) => { AnyVerificationContent::Request(r) => {
@ -184,40 +340,71 @@ impl VerificationMachine {
"Received a new verification request", "Received a new verification request",
); );
let request = VerificationRequest::from_request( if let Some(timestamp) = event.timestamp() {
self.verifications.clone(), if Self::is_timestamp_valid(timestamp) {
self.account.clone(), if !event_sent_from_us(&event, r.from_device()) {
self.private_identity.lock().await.clone(), let request = VerificationRequest::from_request(
self.store.clone(), self.verifications.clone(),
event.sender(), self.account.clone(),
flow_id, self.private_identity.lock().await.clone(),
r, self.store.clone(),
); event.sender(),
flow_id,
r,
);
self.requests.insert(request.flow_id().as_str().to_owned(), request); self.insert_request(request);
} else {
trace!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
"The received verification request was sent by us, ignoring it",
);
}
} else {
trace!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
timestamp =? timestamp,
"The received verification request was too old or too far into the future",
);
}
} else {
warn!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
"The key verification request didn't contain a valid timestamp"
);
}
} }
AnyVerificationContent::Cancel(c) => { AnyVerificationContent::Cancel(c) => {
if let Some(verification) = self.get_request(flow_id.as_str()) { if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
verification.receive_cancel(event.sender(), c); verification.receive_cancel(event.sender(), c);
} }
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) { if let Some(verification) =
// This won't produce an outgoing content self.get_verification(event.sender(), flow_id.as_str())
let _ = sas.receive_any_event(event.sender(), &content); {
match verification {
Verification::SasV1(sas) => {
// This won't produce an outgoing content
let _ = sas.receive_any_event(event.sender(), &content);
}
Verification::QrV1(qr) => qr.receive_cancel(event.sender(), c),
}
} }
} }
AnyVerificationContent::Ready(c) => { AnyVerificationContent::Ready(c) => {
if let Some(request) = self.requests.get(flow_id.as_str()) { if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
if request.flow_id() == &flow_id { if request.flow_id() == &flow_id {
// TODO remove this unwrap. request.receive_ready(event.sender(), c);
request.receive_ready(event.sender(), c).unwrap();
} else { } else {
flow_id_mismatch(); flow_id_mismatch();
} }
} }
} }
AnyVerificationContent::Start(c) => { AnyVerificationContent::Start(c) => {
if let Some(request) = self.requests.get(flow_id.as_str()) { if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
if request.flow_id() == &flow_id { if request.flow_id() == &flow_id {
request.receive_start(event.sender(), c).await? request.receive_start(event.sender(), c).await?
} else { } else {
@ -240,6 +427,7 @@ impl VerificationMachine {
private_identity, private_identity,
device, device,
identity, identity,
None,
false, false,
) { ) {
Ok(sas) => { Ok(sas) => {
@ -255,7 +443,7 @@ impl VerificationMachine {
} }
} }
AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => { AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => {
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) { if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
if sas.flow_id() == &flow_id { if sas.flow_id() == &flow_id {
if let Some(content) = sas.receive_any_event(event.sender(), &content) { if let Some(content) = sas.receive_any_event(event.sender(), &content) {
self.queue_up_content( self.queue_up_content(
@ -270,7 +458,7 @@ impl VerificationMachine {
} }
} }
AnyVerificationContent::Mac(_) => { AnyVerificationContent::Mac(_) => {
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) { if let Some(s) = self.get_sas(event.sender(), flow_id.as_str()) {
if s.flow_id() == &flow_id { if s.flow_id() == &flow_id {
let content = s.receive_any_event(event.sender(), &content); let content = s.receive_any_event(event.sender(), &content);
@ -283,16 +471,30 @@ impl VerificationMachine {
} }
} }
AnyVerificationContent::Done(c) => { AnyVerificationContent::Done(c) => {
if let Some(verification) = self.get_request(flow_id.as_str()) { if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
verification.receive_done(event.sender(), c); verification.receive_done(event.sender(), c);
} }
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) { match self.get_verification(event.sender(), flow_id.as_str()) {
let content = s.receive_any_event(event.sender(), &content); Some(Verification::SasV1(sas)) => {
let content = sas.receive_any_event(event.sender(), &content);
if s.is_done() { if sas.is_done() {
self.mark_sas_as_done(s, content).await?; self.mark_sas_as_done(sas, content).await?;
}
} }
Some(Verification::QrV1(qr)) => {
let (cancellation, request) = qr.receive_done(&c).await?;
if let Some(c) = cancellation {
self.verifications.add_request(c.into())
}
if let Some(s) = request {
self.verifications.add_request(s.into())
}
}
None => (),
} }
} }
} }
@ -363,6 +565,8 @@ mod test {
bob_store, bob_store,
None, None,
None, None,
true,
None,
); );
machine machine
@ -385,7 +589,7 @@ mod test {
async fn full_flow() { async fn full_flow() {
let (alice_machine, bob) = setup_verification_machine().await; let (alice_machine, bob) = setup_verification_machine().await;
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
let request = alice.accept().unwrap(); let request = alice.accept().unwrap();
@ -431,7 +635,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn timing_out() { async fn timing_out() {
let (alice_machine, bob) = setup_verification_machine().await; let (alice_machine, bob) = setup_verification_machine().await;
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
assert!(!alice.timed_out()); assert!(!alice.timed_out());
assert!(alice_machine.verifications.outgoing_requests().is_empty()); assert!(alice_machine.verifications.outgoing_requests().is_empty());

View File

@ -15,6 +15,7 @@
mod cache; mod cache;
mod event_enums; mod event_enums;
mod machine; mod machine;
mod qrcode;
mod requests; mod requests;
mod sas; mod sas;
@ -22,6 +23,7 @@ use std::sync::Arc;
use event_enums::OutgoingContent; use event_enums::OutgoingContent;
pub use machine::VerificationMachine; pub use machine::VerificationMachine;
pub use qrcode::QrVerification;
pub use requests::VerificationRequest; pub use requests::VerificationRequest;
use ruma::{ use ruma::{
api::client::r0::keys::upload_signatures::Request as SignatureUploadRequest, api::client::r0::keys::upload_signatures::Request as SignatureUploadRequest,
@ -41,25 +43,83 @@ use tracing::{error, info, trace, warn};
use crate::{ use crate::{
error::SignatureError, error::SignatureError,
olm::PrivateCrossSigningIdentity, olm::PrivateCrossSigningIdentity,
store::{Changes, CryptoStore, DeviceChanges}, store::{Changes, CryptoStore},
CryptoStoreError, LocalTrust, ReadOnlyDevice, UserIdentities, CryptoStoreError, LocalTrust, ReadOnlyDevice, ReadOnlyUserIdentities,
}; };
/// An enum over the different verification types the SDK supports.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum Verification { pub enum Verification {
/// The `m.sas.v1` verification variant.
SasV1(Sas), SasV1(Sas),
/// The `m.qr_code.*.v1` verification variant.
QrV1(QrVerification),
} }
impl Verification { impl Verification {
pub fn is_done(&self) -> bool { /// Try to deconstruct this verification enum into a SAS verification.
match self { pub fn sas_v1(self) -> Option<Sas> {
Verification::SasV1(s) => s.is_done(), if let Verification::SasV1(sas) = self {
Some(sas)
} else {
None
} }
} }
/// Try to deconstruct this verification enum into a QR code verification.
pub fn qr_v1(self) -> Option<QrVerification> {
if let Verification::QrV1(qr) = self {
Some(qr)
} else {
None
}
}
/// Has this verification finished.
pub fn is_done(&self) -> bool {
match self {
Verification::SasV1(s) => s.is_done(),
Verification::QrV1(qr) => qr.is_done(),
}
}
/// Get the ID that uniquely identifies this verification flow.
pub fn flow_id(&self) -> &str {
match self {
Verification::SasV1(s) => s.flow_id().as_str(),
Verification::QrV1(qr) => qr.flow_id().as_str(),
}
}
/// Has the verification been cancelled.
pub fn is_cancelled(&self) -> bool { pub fn is_cancelled(&self) -> bool {
match self { match self {
Verification::SasV1(s) => s.is_cancelled(), Verification::SasV1(s) => s.is_cancelled(),
Verification::QrV1(qr) => qr.is_cancelled(),
}
}
/// Get our own user id that is participating in this verification.
pub fn user_id(&self) -> &UserId {
match self {
Verification::SasV1(v) => v.user_id(),
Verification::QrV1(v) => v.user_id(),
}
}
/// Get the other user id that is participating in this verification.
pub fn other_user(&self) -> &UserId {
match self {
Verification::SasV1(s) => s.other_user_id(),
Verification::QrV1(qr) => qr.other_user_id(),
}
}
/// Is this a verification verifying a device that belongs to us.
pub fn is_self_verification(&self) -> bool {
match self {
Verification::SasV1(v) => v.is_self_verification(),
Verification::QrV1(v) => v.is_self_verification(),
} }
} }
} }
@ -70,6 +130,12 @@ impl From<Sas> for Verification {
} }
} }
impl From<QrVerification> for Verification {
fn from(qr: QrVerification) -> Self {
Self::QrV1(qr)
}
}
/// The verification state indicating that the verification finished /// The verification state indicating that the verification finished
/// successfully. /// successfully.
/// ///
@ -78,7 +144,7 @@ impl From<Sas> for Verification {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Done { pub struct Done {
verified_devices: Arc<[ReadOnlyDevice]>, verified_devices: Arc<[ReadOnlyDevice]>,
verified_master_keys: Arc<[UserIdentities]>, verified_master_keys: Arc<[ReadOnlyUserIdentities]>,
} }
impl Done { impl Done {
@ -99,14 +165,47 @@ impl Done {
} }
} }
/// Information about the cancellation of a verification request or verification
/// flow.
#[derive(Clone, Debug)]
pub struct CancelInfo {
cancelled_by_us: bool,
cancel_code: CancelCode,
reason: &'static str,
}
impl CancelInfo {
/// Get the human readable reason of the cancellation.
pub fn reason(&self) -> &'static str {
&self.reason
}
/// Get the `CancelCode` that cancelled this verification.
pub fn cancel_code(&self) -> &CancelCode {
&self.cancel_code
}
/// Was the verification cancelled by us?
pub fn cancelled_by_us(&self) -> bool {
self.cancelled_by_us
}
}
impl From<Cancelled> for CancelInfo {
fn from(c: Cancelled) -> Self {
Self { cancelled_by_us: c.cancelled_by_us, cancel_code: c.cancel_code, reason: c.reason }
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Cancelled { pub struct Cancelled {
cancelled_by_us: bool,
cancel_code: CancelCode, cancel_code: CancelCode,
reason: &'static str, reason: &'static str,
} }
impl Cancelled { impl Cancelled {
fn new(code: CancelCode) -> Self { fn new(cancelled_by_us: bool, code: CancelCode) -> Self {
let reason = match code { let reason = match code {
CancelCode::Accepted => { CancelCode::Accepted => {
"A m.key.verification.request was accepted by a different device." "A m.key.verification.request was accepted by a different device."
@ -126,7 +225,7 @@ impl Cancelled {
_ => "Unknown cancel reason", _ => "Unknown cancel reason",
}; };
Self { cancel_code: code, reason } Self { cancelled_by_us, cancel_code: code, reason }
} }
pub fn as_content(&self, flow_id: &FlowId) -> OutgoingContent { pub fn as_content(&self, flow_id: &FlowId) -> OutgoingContent {
@ -210,14 +309,22 @@ pub struct IdentitiesBeingVerified {
private_identity: PrivateCrossSigningIdentity, private_identity: PrivateCrossSigningIdentity,
store: Arc<dyn CryptoStore>, store: Arc<dyn CryptoStore>,
device_being_verified: ReadOnlyDevice, device_being_verified: ReadOnlyDevice,
identity_being_verified: Option<UserIdentities>, identity_being_verified: Option<ReadOnlyUserIdentities>,
} }
impl IdentitiesBeingVerified { impl IdentitiesBeingVerified {
async fn can_sign_devices(&self) -> bool {
self.private_identity.can_sign_devices().await
}
fn user_id(&self) -> &UserId { fn user_id(&self) -> &UserId {
self.private_identity.user_id() self.private_identity.user_id()
} }
fn is_self_verification(&self) -> bool {
self.user_id() == self.other_user_id()
}
fn other_user_id(&self) -> &UserId { fn other_user_id(&self) -> &UserId {
self.device_being_verified.user_id() self.device_being_verified.user_id()
} }
@ -233,11 +340,20 @@ impl IdentitiesBeingVerified {
pub async fn mark_as_done( pub async fn mark_as_done(
&self, &self,
verified_devices: Option<&[ReadOnlyDevice]>, verified_devices: Option<&[ReadOnlyDevice]>,
verified_identities: Option<&[UserIdentities]>, verified_identities: Option<&[ReadOnlyUserIdentities]>,
) -> Result<VerificationResult, CryptoStoreError> { ) -> Result<VerificationResult, CryptoStoreError> {
if let Some(device) = self.mark_device_as_verified(verified_devices).await? { let device = self.mark_device_as_verified(verified_devices).await?;
let identity = self.mark_identity_as_verified(verified_identities).await?; let identity = self.mark_identity_as_verified(verified_identities).await?;
if device.is_none() && identity.is_none() {
// Something wen't wrong if nothing was verified, we use key
// mismatch here, since it's the closest to nothing was verified
return Ok(VerificationResult::Cancel(CancelCode::KeyMismatch));
}
let mut changes = Changes::default();
let signature_request = if let Some(device) = device {
// We only sign devices of our own user here. // We only sign devices of our own user here.
let signature_request = if device.user_id() == self.user_id() { let signature_request = if device.user_id() == self.user_id() {
match self.private_identity.sign_device(&device).await { match self.private_identity.sign_device(&device).await {
@ -266,83 +382,79 @@ impl IdentitiesBeingVerified {
None None
}; };
let mut changes = Changes { changes.devices.changed.push(device);
devices: DeviceChanges { changed: vec![device], ..Default::default() }, signature_request
..Default::default() } else {
}; None
};
let identity_signature_request = if let Some(i) = identity { let identity_signature_request = if let Some(i) = identity {
// We only sign other users here. // We only sign other users here.
let request = if let Some(i) = i.other() { let request = if let Some(i) = i.other() {
// Signing can fail if the user signing key is missing. // Signing can fail if the user signing key is missing.
match self.private_identity.sign_user(i).await { match self.private_identity.sign_user(i).await {
Ok(r) => Some(r), Ok(r) => Some(r),
Err(SignatureError::MissingSigningKey) => { Err(SignatureError::MissingSigningKey) => {
warn!( warn!(
"Can't sign the public cross signing keys for {}, \ "Can't sign the public cross signing keys for {}, \
no private user signing key found", no private user signing key found",
i.user_id() i.user_id()
); );
None None
}
Err(e) => {
error!(
"Error signing the public cross signing keys for {} {:?}",
i.user_id(),
e
);
None
}
} }
} else { Err(e) => {
None error!(
}; "Error signing the public cross signing keys for {} {:?}",
i.user_id(),
changes.identities.changed.push(i); e
);
request None
}
}
} else { } else {
None None
}; };
// If there are two signature upload requests, merge them. Otherwise changes.identities.changed.push(i);
// use the one we have or None. request
//
// Realistically at most one request will be used but let's make
// this future proof.
let merged_request = if let Some(mut r) = signature_request {
if let Some(user_request) = identity_signature_request {
r.signed_keys.extend(user_request.signed_keys);
Some(r)
} else {
Some(r)
}
} else {
identity_signature_request
};
// TODO store the signature upload request as well.
self.store.save_changes(changes).await?;
Ok(merged_request
.map(VerificationResult::SignatureUpload)
.unwrap_or(VerificationResult::Ok))
} else { } else {
Ok(VerificationResult::Cancel(CancelCode::UserMismatch)) None
} };
// If there are two signature upload requests, merge them. Otherwise
// use the one we have or None.
//
// Realistically at most one request will be used but let's make
// this future proof.
let merged_request = if let Some(mut r) = signature_request {
if let Some(user_request) = identity_signature_request {
r.signed_keys.extend(user_request.signed_keys);
Some(r)
} else {
Some(r)
}
} else {
identity_signature_request
};
// TODO store the signature upload request as well.
self.store.save_changes(changes).await?;
Ok(merged_request
.map(VerificationResult::SignatureUpload)
.unwrap_or(VerificationResult::Ok))
} }
async fn mark_identity_as_verified( async fn mark_identity_as_verified(
&self, &self,
verified_identities: Option<&[UserIdentities]>, verified_identities: Option<&[ReadOnlyUserIdentities]>,
) -> Result<Option<UserIdentities>, CryptoStoreError> { ) -> Result<Option<ReadOnlyUserIdentities>, CryptoStoreError> {
// If there wasn't an identity available during the verification flow // If there wasn't an identity available during the verification flow
// return early as there's nothing to do. // return early as there's nothing to do.
if self.identity_being_verified.is_none() { if self.identity_being_verified.is_none() {
return Ok(None); return Ok(None);
} }
// TODO signal an error, e.g. when the identity got deleted so we don't
// verify/save the device either.
let identity = self.store.get_user_identity(self.other_user_id()).await?; let identity = self.store.get_user_identity(self.other_user_id()).await?;
if let Some(identity) = identity { if let Some(identity) = identity {
@ -357,7 +469,7 @@ impl IdentitiesBeingVerified {
"Marking the user identity of as verified." "Marking the user identity of as verified."
); );
if let UserIdentities::Own(i) = &identity { if let ReadOnlyUserIdentities::Own(i) = &identity {
i.mark_as_verified(); i.mark_as_verified();
} }
@ -445,11 +557,12 @@ impl IdentitiesBeingVerified {
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub(crate) mod test {
use std::convert::TryInto;
use ruma::{ use ruma::{
events::{AnyToDeviceEvent, AnyToDeviceEventContent, EventType, ToDeviceEvent}, events::{AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent},
UserId, UserId,
}; };
use serde_json::Value;
use super::event_enums::OutgoingContent; use super::event_enums::OutgoingContent;
use crate::{ use crate::{
@ -461,7 +574,8 @@ pub(crate) mod test {
sender: &UserId, sender: &UserId,
request: &OutgoingVerificationRequest, request: &OutgoingVerificationRequest,
) -> AnyToDeviceEvent { ) -> AnyToDeviceEvent {
let content = get_content_from_request(request); let content =
request.to_owned().try_into().expect("Can't fetch content out of the request");
wrap_any_to_device_content(sender, content) wrap_any_to_device_content(sender, content)
} }
@ -510,36 +624,4 @@ pub(crate) mod test {
_ => unreachable!(), _ => unreachable!(),
} }
} }
pub(crate) fn get_content_from_request(
request: &OutgoingVerificationRequest,
) -> OutgoingContent {
let request =
if let OutgoingVerificationRequest::ToDevice(r) = request { r } else { unreachable!() };
let json: Value = serde_json::from_str(
request.messages.values().next().unwrap().values().next().unwrap().get(),
)
.unwrap();
match request.event_type {
EventType::KeyVerificationStart => {
AnyToDeviceEventContent::KeyVerificationStart(serde_json::from_value(json).unwrap())
}
EventType::KeyVerificationKey => {
AnyToDeviceEventContent::KeyVerificationKey(serde_json::from_value(json).unwrap())
}
EventType::KeyVerificationAccept => AnyToDeviceEventContent::KeyVerificationAccept(
serde_json::from_value(json).unwrap(),
),
EventType::KeyVerificationMac => {
AnyToDeviceEventContent::KeyVerificationMac(serde_json::from_value(json).unwrap())
}
EventType::KeyVerificationCancel => AnyToDeviceEventContent::KeyVerificationCancel(
serde_json::from_value(json).unwrap(),
),
_ => unreachable!(),
}
.into()
}
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -14,17 +14,15 @@
use std::{collections::BTreeMap, convert::TryInto}; use std::{collections::BTreeMap, convert::TryInto};
use matrix_sdk_common::uuid::Uuid;
use olm_rs::sas::OlmSas; use olm_rs::sas::OlmSas;
use ruma::{ use ruma::{
api::client::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
key::verification::{ key::verification::{
cancel::CancelCode, cancel::CancelCode,
mac::{MacEventContent, MacToDeviceEventContent}, mac::{MacEventContent, MacToDeviceEventContent},
Relation, Relation,
}, },
AnyMessageEventContent, AnyToDeviceEventContent, EventType, AnyMessageEventContent, AnyToDeviceEventContent,
}, },
DeviceKeyAlgorithm, DeviceKeyId, UserId, DeviceKeyAlgorithm, DeviceKeyId, UserId,
}; };
@ -33,17 +31,17 @@ use tracing::{trace, warn};
use super::{FlowId, OutgoingContent}; use super::{FlowId, OutgoingContent};
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentities}, identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
utilities::encode, utilities::encode,
verification::event_enums::{MacContent, StartContent}, verification::event_enums::{MacContent, StartContent},
ReadOnlyAccount, ToDeviceRequest, ReadOnlyAccount,
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct SasIds { pub struct SasIds {
pub account: ReadOnlyAccount, pub account: ReadOnlyAccount,
pub other_device: ReadOnlyDevice, pub other_device: ReadOnlyDevice,
pub other_identity: Option<UserIdentities>, pub other_identity: Option<ReadOnlyUserIdentities>,
} }
/// Calculate the commitment for a accept event from the public key and the /// Calculate the commitment for a accept event from the public key and the
@ -184,7 +182,7 @@ pub fn receive_mac_event(
flow_id: &str, flow_id: &str,
sender: &UserId, sender: &UserId,
content: &MacContent, content: &MacContent,
) -> Result<(Vec<ReadOnlyDevice>, Vec<UserIdentities>), CancelCode> { ) -> Result<(Vec<ReadOnlyDevice>, Vec<ReadOnlyUserIdentities>), CancelCode> {
let mut verified_devices = Vec::new(); let mut verified_devices = Vec::new();
let mut verified_identities = Vec::new(); let mut verified_identities = Vec::new();
@ -527,34 +525,6 @@ fn bytes_to_decimal(bytes: Vec<u8>) -> (u16, u16, u16) {
(first + 1000, second + 1000, third + 1000) (first + 1000, second + 1000, third + 1000)
} }
pub fn content_to_request(
recipient: &UserId,
recipient_device: impl Into<DeviceIdOrAllDevices>,
content: AnyToDeviceEventContent,
) -> ToDeviceRequest {
let mut messages = BTreeMap::new();
let mut user_messages = BTreeMap::new();
user_messages.insert(
recipient_device.into(),
serde_json::value::to_raw_value(&content).expect("Can't serialize to-device content"),
);
messages.insert(recipient.clone(), user_messages);
let event_type = match content {
AnyToDeviceEventContent::KeyVerificationAccept(_) => EventType::KeyVerificationAccept,
AnyToDeviceEventContent::KeyVerificationStart(_) => EventType::KeyVerificationStart,
AnyToDeviceEventContent::KeyVerificationKey(_) => EventType::KeyVerificationKey,
AnyToDeviceEventContent::KeyVerificationMac(_) => EventType::KeyVerificationMac,
AnyToDeviceEventContent::KeyVerificationCancel(_) => EventType::KeyVerificationCancel,
AnyToDeviceEventContent::KeyVerificationReady(_) => EventType::KeyVerificationReady,
AnyToDeviceEventContent::KeyVerificationDone(_) => EventType::KeyVerificationDone,
_ => unreachable!(),
};
ToDeviceRequest { txn_id: Uuid::new_v4(), event_type, messages }
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use proptest::prelude::*; use proptest::prelude::*;

View File

@ -24,11 +24,12 @@ use ruma::{
use super::{ use super::{
sas_state::{ sas_state::{
Accepted, Confirmed, Created, KeyReceived, MacReceived, SasState, Started, WaitingForDone, Accepted, Confirmed, Created, KeyReceived, MacReceived, SasState, Started, WaitingForDone,
WeAccepted,
}, },
FlowId, FlowId,
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentities}, identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
verification::{ verification::{
event_enums::{AnyVerificationContent, OutgoingContent, OwnedAcceptContent, StartContent}, event_enums::{AnyVerificationContent, OutgoingContent, OwnedAcceptContent, StartContent},
Cancelled, Done, Cancelled, Done,
@ -41,6 +42,7 @@ pub enum InnerSas {
Created(SasState<Created>), Created(SasState<Created>),
Started(SasState<Started>), Started(SasState<Started>),
Accepted(SasState<Accepted>), Accepted(SasState<Accepted>),
WeAccepted(SasState<WeAccepted>),
KeyReceived(SasState<KeyReceived>), KeyReceived(SasState<KeyReceived>),
Confirmed(SasState<Confirmed>), Confirmed(SasState<Confirmed>),
MacReceived(SasState<MacReceived>), MacReceived(SasState<MacReceived>),
@ -53,7 +55,7 @@ impl InnerSas {
pub fn start( pub fn start(
account: ReadOnlyAccount, account: ReadOnlyAccount,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
transaction_id: Option<String>, transaction_id: Option<String>,
) -> (InnerSas, OutgoingContent) { ) -> (InnerSas, OutgoingContent) {
let sas = SasState::<Created>::new(account, other_device, other_identity, transaction_id); let sas = SasState::<Created>::new(account, other_device, other_identity, transaction_id);
@ -61,6 +63,34 @@ impl InnerSas {
(InnerSas::Created(sas), content.into()) (InnerSas::Created(sas), content.into())
} }
pub fn started_from_request(&self) -> bool {
match self {
InnerSas::Created(s) => s.started_from_request,
InnerSas::Started(s) => s.started_from_request,
InnerSas::WeAccepted(s) => s.started_from_request,
InnerSas::Accepted(s) => s.started_from_request,
InnerSas::KeyReceived(s) => s.started_from_request,
InnerSas::Confirmed(s) => s.started_from_request,
InnerSas::MacReceived(s) => s.started_from_request,
InnerSas::WaitingForDone(s) => s.started_from_request,
InnerSas::Done(s) => s.started_from_request,
InnerSas::Cancelled(s) => s.started_from_request,
}
}
pub fn has_been_accepted(&self) -> bool {
match self {
InnerSas::Created(_) | InnerSas::Started(_) | InnerSas::Cancelled(_) => false,
InnerSas::Accepted(_)
| InnerSas::WeAccepted(_)
| InnerSas::KeyReceived(_)
| InnerSas::Confirmed(_)
| InnerSas::MacReceived(_)
| InnerSas::WaitingForDone(_)
| InnerSas::Done(_) => true,
}
}
pub fn supports_emoji(&self) -> bool { pub fn supports_emoji(&self) -> bool {
match self { match self {
InnerSas::Created(_) => false, InnerSas::Created(_) => false,
@ -69,6 +99,11 @@ impl InnerSas {
.accepted_protocols .accepted_protocols
.short_auth_string .short_auth_string
.contains(&ShortAuthenticationString::Emoji), .contains(&ShortAuthenticationString::Emoji),
InnerSas::WeAccepted(s) => s
.state
.accepted_protocols
.short_auth_string
.contains(&ShortAuthenticationString::Emoji),
InnerSas::Accepted(s) => s InnerSas::Accepted(s) => s
.state .state
.accepted_protocols .accepted_protocols
@ -96,7 +131,7 @@ impl InnerSas {
room_id: RoomId, room_id: RoomId,
account: ReadOnlyAccount, account: ReadOnlyAccount,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
) -> (InnerSas, OutgoingContent) { ) -> (InnerSas, OutgoingContent) {
let sas = SasState::<Created>::new_in_room( let sas = SasState::<Created>::new_in_room(
room_id, room_id,
@ -114,7 +149,7 @@ impl InnerSas {
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
flow_id: FlowId, flow_id: FlowId,
content: &StartContent, content: &StartContent,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
started_from_request: bool, started_from_request: bool,
) -> Result<InnerSas, OutgoingContent> { ) -> Result<InnerSas, OutgoingContent> {
match SasState::<Started>::from_start_event( match SasState::<Started>::from_start_event(
@ -130,9 +165,14 @@ impl InnerSas {
} }
} }
pub fn accept(&self) -> Option<OwnedAcceptContent> { pub fn accept(
self,
methods: Vec<ShortAuthenticationString>,
) -> Option<(InnerSas, OwnedAcceptContent)> {
if let InnerSas::Started(s) = self { if let InnerSas::Started(s) = self {
Some(s.as_content()) let sas = s.into_accepted(methods);
let content = sas.as_content();
Some((InnerSas::WeAccepted(sas), content))
} else { } else {
None None
} }
@ -151,17 +191,25 @@ impl InnerSas {
InnerSas::MacReceived(s) => s.set_creation_time(time), InnerSas::MacReceived(s) => s.set_creation_time(time),
InnerSas::Done(s) => s.set_creation_time(time), InnerSas::Done(s) => s.set_creation_time(time),
InnerSas::WaitingForDone(s) => s.set_creation_time(time), InnerSas::WaitingForDone(s) => s.set_creation_time(time),
InnerSas::WeAccepted(s) => s.set_creation_time(time),
} }
} }
pub fn cancel(self, code: CancelCode) -> (InnerSas, Option<OutgoingContent>) { pub fn cancel(
self,
cancelled_by_us: bool,
code: CancelCode,
) -> (InnerSas, Option<OutgoingContent>) {
let sas = match self { let sas = match self {
InnerSas::Created(s) => s.cancel(code), InnerSas::Created(s) => s.cancel(cancelled_by_us, code),
InnerSas::Started(s) => s.cancel(code), InnerSas::Started(s) => s.cancel(cancelled_by_us, code),
InnerSas::Accepted(s) => s.cancel(code), InnerSas::Accepted(s) => s.cancel(cancelled_by_us, code),
InnerSas::KeyReceived(s) => s.cancel(code), InnerSas::WeAccepted(s) => s.cancel(cancelled_by_us, code),
InnerSas::MacReceived(s) => s.cancel(code), InnerSas::KeyReceived(s) => s.cancel(cancelled_by_us, code),
_ => return (self, None), InnerSas::MacReceived(s) => s.cancel(cancelled_by_us, code),
InnerSas::Confirmed(s) => s.cancel(cancelled_by_us, code),
InnerSas::WaitingForDone(s) => s.cancel(cancelled_by_us, code),
InnerSas::Done(_) | InnerSas::Cancelled(_) => return (self, None),
}; };
let content = sas.as_content(); let content = sas.as_content();
@ -216,7 +264,7 @@ impl InnerSas {
} }
} }
AnyVerificationContent::Cancel(c) => { AnyVerificationContent::Cancel(c) => {
let (sas, _) = self.cancel(c.cancel_code().to_owned()); let (sas, _) = self.cancel(false, c.cancel_code().to_owned());
(sas, None) (sas, None)
} }
AnyVerificationContent::Key(c) => match self { AnyVerificationContent::Key(c) => match self {
@ -227,7 +275,7 @@ impl InnerSas {
(InnerSas::Cancelled(s), Some(content)) (InnerSas::Cancelled(s), Some(content))
} }
}, },
InnerSas::Started(s) => match s.into_key_received(sender, c) { InnerSas::WeAccepted(s) => match s.into_key_received(sender, c) {
Ok(s) => { Ok(s) => {
let content = s.as_content(); let content = s.as_content();
(InnerSas::KeyReceived(s), Some(content)) (InnerSas::KeyReceived(s), Some(content))
@ -298,6 +346,10 @@ impl InnerSas {
matches!(self, InnerSas::Cancelled(_)) matches!(self, InnerSas::Cancelled(_))
} }
pub fn have_we_confirmed(&self) -> bool {
matches!(self, InnerSas::Confirmed(_) | InnerSas::WaitingForDone(_) | InnerSas::Done(_))
}
pub fn timed_out(&self) -> bool { pub fn timed_out(&self) -> bool {
match self { match self {
InnerSas::Created(s) => s.timed_out(), InnerSas::Created(s) => s.timed_out(),
@ -309,6 +361,7 @@ impl InnerSas {
InnerSas::MacReceived(s) => s.timed_out(), InnerSas::MacReceived(s) => s.timed_out(),
InnerSas::WaitingForDone(s) => s.timed_out(), InnerSas::WaitingForDone(s) => s.timed_out(),
InnerSas::Done(s) => s.timed_out(), InnerSas::Done(s) => s.timed_out(),
InnerSas::WeAccepted(s) => s.timed_out(),
} }
} }
@ -323,6 +376,7 @@ impl InnerSas {
InnerSas::MacReceived(s) => s.verification_flow_id.clone(), InnerSas::MacReceived(s) => s.verification_flow_id.clone(),
InnerSas::WaitingForDone(s) => s.verification_flow_id.clone(), InnerSas::WaitingForDone(s) => s.verification_flow_id.clone(),
InnerSas::Done(s) => s.verification_flow_id.clone(), InnerSas::Done(s) => s.verification_flow_id.clone(),
InnerSas::WeAccepted(s) => s.verification_flow_id.clone(),
} }
} }
@ -358,7 +412,7 @@ impl InnerSas {
} }
} }
pub fn verified_identities(&self) -> Option<Arc<[UserIdentities]>> { pub fn verified_identities(&self) -> Option<Arc<[ReadOnlyUserIdentities]>> {
if let InnerSas::Done(s) = self { if let InnerSas::Done(s) = self {
Some(s.verified_identities()) Some(s.verified_identities())
} else { } else {

View File

@ -20,17 +20,12 @@ use std::sync::{Arc, Mutex};
#[cfg(test)] #[cfg(test)]
use std::time::Instant; use std::time::Instant;
pub use helpers::content_to_request;
use inner_sas::InnerSas; use inner_sas::InnerSas;
use matrix_sdk_common::uuid::Uuid; use matrix_sdk_common::uuid::Uuid;
use ruma::{ use ruma::{
api::client::r0::keys::upload_signatures::Request as SignatureUploadRequest, api::client::r0::keys::upload_signatures::Request as SignatureUploadRequest,
events::{ events::{
key::verification::{ key::verification::{cancel::CancelCode, ShortAuthenticationString},
accept::{AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent},
cancel::CancelCode,
ShortAuthenticationString,
},
AnyMessageEventContent, AnyToDeviceEventContent, AnyMessageEventContent, AnyToDeviceEventContent,
}, },
DeviceId, EventId, RoomId, UserId, DeviceId, EventId, RoomId, UserId,
@ -39,23 +34,26 @@ use tracing::trace;
use super::{ use super::{
event_enums::{AnyVerificationContent, OutgoingContent, OwnedAcceptContent, StartContent}, event_enums::{AnyVerificationContent, OutgoingContent, OwnedAcceptContent, StartContent},
FlowId, IdentitiesBeingVerified, VerificationResult, requests::RequestHandle,
CancelInfo, FlowId, IdentitiesBeingVerified, VerificationResult,
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentities}, identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
olm::PrivateCrossSigningIdentity, olm::PrivateCrossSigningIdentity,
requests::{OutgoingVerificationRequest, RoomMessageRequest}, requests::{OutgoingVerificationRequest, RoomMessageRequest},
store::{CryptoStore, CryptoStoreError}, store::{CryptoStore, CryptoStoreError},
ReadOnlyAccount, ToDeviceRequest, ReadOnlyAccount, ToDeviceRequest,
}; };
#[derive(Clone, Debug)]
/// Short authentication string object. /// Short authentication string object.
#[derive(Clone, Debug)]
pub struct Sas { pub struct Sas {
inner: Arc<Mutex<InnerSas>>, inner: Arc<Mutex<InnerSas>>,
account: ReadOnlyAccount, account: ReadOnlyAccount,
identities_being_verified: IdentitiesBeingVerified, identities_being_verified: IdentitiesBeingVerified,
flow_id: Arc<FlowId>, flow_id: Arc<FlowId>,
we_started: bool,
request_handle: Option<RequestHandle>,
} }
impl Sas { impl Sas {
@ -89,25 +87,72 @@ impl Sas {
&self.flow_id &self.flow_id
} }
/// Get the room id if the verification is happening inside a room.
pub fn room_id(&self) -> Option<&RoomId> {
if let FlowId::InRoom(r, _) = self.flow_id() {
Some(r)
} else {
None
}
}
/// Does this verification flow support displaying emoji for the short /// Does this verification flow support displaying emoji for the short
/// authentication string. /// authentication string.
pub fn supports_emoji(&self) -> bool { pub fn supports_emoji(&self) -> bool {
self.inner.lock().unwrap().supports_emoji() self.inner.lock().unwrap().supports_emoji()
} }
/// Did this verification flow start from a verification request.
pub fn started_from_request(&self) -> bool {
self.inner.lock().unwrap().started_from_request()
}
/// Is this a verification that is veryfying one of our own devices.
pub fn is_self_verification(&self) -> bool {
self.identities_being_verified.is_self_verification()
}
/// Have we confirmed that the short auth string matches.
pub fn have_we_confirmed(&self) -> bool {
self.inner.lock().unwrap().have_we_confirmed()
}
/// Has the verification been accepted by both parties.
pub fn has_been_accepted(&self) -> bool {
self.inner.lock().unwrap().has_been_accepted()
}
/// Get info about the cancellation if the verification flow has been
/// cancelled.
pub fn cancel_info(&self) -> Option<CancelInfo> {
if let InnerSas::Cancelled(c) = &*self.inner.lock().unwrap() {
Some(c.state.as_ref().clone().into())
} else {
None
}
}
/// Did we initiate the verification flow.
pub fn we_started(&self) -> bool {
self.we_started
}
#[cfg(test)] #[cfg(test)]
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn set_creation_time(&self, time: Instant) { pub(crate) fn set_creation_time(&self, time: Instant) {
self.inner.lock().unwrap().set_creation_time(time) self.inner.lock().unwrap().set_creation_time(time)
} }
#[allow(clippy::too_many_arguments)]
fn start_helper( fn start_helper(
inner_sas: InnerSas, inner_sas: InnerSas,
account: ReadOnlyAccount, account: ReadOnlyAccount,
private_identity: PrivateCrossSigningIdentity, private_identity: PrivateCrossSigningIdentity,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
store: Arc<dyn CryptoStore>, store: Arc<dyn CryptoStore>,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
we_started: bool,
request_handle: Option<RequestHandle>,
) -> Sas { ) -> Sas {
let flow_id = inner_sas.verification_flow_id(); let flow_id = inner_sas.verification_flow_id();
@ -123,6 +168,8 @@ impl Sas {
account, account,
identities_being_verified: identities, identities_being_verified: identities,
flow_id, flow_id,
we_started,
request_handle,
} }
} }
@ -136,13 +183,16 @@ impl Sas {
/// ///
/// Returns the new `Sas` object and a `StartEventContent` that needs to be /// Returns the new `Sas` object and a `StartEventContent` that needs to be
/// sent out through the server to the other device. /// sent out through the server to the other device.
#[allow(clippy::too_many_arguments)]
pub(crate) fn start( pub(crate) fn start(
account: ReadOnlyAccount, account: ReadOnlyAccount,
private_identity: PrivateCrossSigningIdentity, private_identity: PrivateCrossSigningIdentity,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
store: Arc<dyn CryptoStore>, store: Arc<dyn CryptoStore>,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
transaction_id: Option<String>, transaction_id: Option<String>,
we_started: bool,
request_handle: Option<RequestHandle>,
) -> (Sas, OutgoingContent) { ) -> (Sas, OutgoingContent) {
let (inner, content) = InnerSas::start( let (inner, content) = InnerSas::start(
account.clone(), account.clone(),
@ -159,6 +209,8 @@ impl Sas {
other_device, other_device,
store, store,
other_identity, other_identity,
we_started,
request_handle,
), ),
content, content,
) )
@ -174,6 +226,7 @@ impl Sas {
/// ///
/// Returns the new `Sas` object and a `StartEventContent` that needs to be /// Returns the new `Sas` object and a `StartEventContent` that needs to be
/// sent out through the server to the other device. /// sent out through the server to the other device.
#[allow(clippy::too_many_arguments)]
pub(crate) fn start_in_room( pub(crate) fn start_in_room(
flow_id: EventId, flow_id: EventId,
room_id: RoomId, room_id: RoomId,
@ -181,7 +234,9 @@ impl Sas {
private_identity: PrivateCrossSigningIdentity, private_identity: PrivateCrossSigningIdentity,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
store: Arc<dyn CryptoStore>, store: Arc<dyn CryptoStore>,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
we_started: bool,
request_handle: RequestHandle,
) -> (Sas, OutgoingContent) { ) -> (Sas, OutgoingContent) {
let (inner, content) = InnerSas::start_in_room( let (inner, content) = InnerSas::start_in_room(
flow_id, flow_id,
@ -199,6 +254,8 @@ impl Sas {
other_device, other_device,
store, store,
other_identity, other_identity,
we_started,
Some(request_handle),
), ),
content, content,
) )
@ -222,8 +279,9 @@ impl Sas {
account: ReadOnlyAccount, account: ReadOnlyAccount,
private_identity: PrivateCrossSigningIdentity, private_identity: PrivateCrossSigningIdentity,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
started_from_request: bool, request_handle: Option<RequestHandle>,
we_started: bool,
) -> Result<Sas, OutgoingContent> { ) -> Result<Sas, OutgoingContent> {
let inner = InnerSas::from_start_event( let inner = InnerSas::from_start_event(
account.clone(), account.clone(),
@ -231,7 +289,7 @@ impl Sas {
flow_id, flow_id,
content, content,
other_identity.clone(), other_identity.clone(),
started_from_request, request_handle.is_some(),
)?; )?;
Ok(Self::start_helper( Ok(Self::start_helper(
@ -241,6 +299,8 @@ impl Sas {
other_device, other_device,
store, store,
other_identity, other_identity,
we_started,
request_handle,
)) ))
} }
@ -262,18 +322,28 @@ impl Sas {
&self, &self,
settings: AcceptSettings, settings: AcceptSettings,
) -> Option<OutgoingVerificationRequest> { ) -> Option<OutgoingVerificationRequest> {
self.inner.lock().unwrap().accept().map(|c| match settings.apply(c) { let mut guard = self.inner.lock().unwrap();
OwnedAcceptContent::ToDevice(c) => { let sas: InnerSas = (*guard).clone();
let content = AnyToDeviceEventContent::KeyVerificationAccept(c); let methods = settings.allowed_methods;
self.content_to_request(content).into()
} if let Some((sas, content)) = sas.accept(methods) {
OwnedAcceptContent::Room(room_id, content) => RoomMessageRequest { *guard = sas;
room_id,
txn_id: Uuid::new_v4(), Some(match content {
content: AnyMessageEventContent::KeyVerificationAccept(content), OwnedAcceptContent::ToDevice(c) => {
} let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
.into(), self.content_to_request(content).into()
}) }
OwnedAcceptContent::Room(room_id, content) => RoomMessageRequest {
room_id,
txn_id: Uuid::new_v4(),
content: AnyMessageEventContent::KeyVerificationAccept(content),
}
.into(),
})
} else {
None
}
} }
/// Confirm the Sas verification. /// Confirm the Sas verification.
@ -340,10 +410,28 @@ impl Sas {
self.cancel_with_code(CancelCode::User) self.cancel_with_code(CancelCode::User)
} }
pub(crate) fn cancel_with_code(&self, code: CancelCode) -> Option<OutgoingVerificationRequest> { /// Cancel the verification.
///
/// This cancels the verification with given `CancelCode`.
///
/// **Note**: This method should generally not be used, the [`cancel()`]
/// method should be preferred. The SDK will automatically cancel with the
/// approprate cancel code, user initiated cancellations should only cancel
/// with the `CancelCode::User`
///
/// Returns None if the `Sas` object is already in a canceled state,
/// otherwise it returns a request that needs to be sent out.
///
/// [`cancel()`]: #method.cancel
pub fn cancel_with_code(&self, code: CancelCode) -> Option<OutgoingVerificationRequest> {
let mut guard = self.inner.lock().unwrap(); let mut guard = self.inner.lock().unwrap();
if let Some(request) = &self.request_handle {
request.cancel_with_code(&code)
}
let sas: InnerSas = (*guard).clone(); let sas: InnerSas = (*guard).clone();
let (sas, content) = sas.cancel(code); let (sas, content) = sas.cancel(true, code);
*guard = sas; *guard = sas;
content.map(|c| match c { content.map(|c| match c {
OutgoingContent::Room(room_id, content) => { OutgoingContent::Room(room_id, content) => {
@ -427,12 +515,12 @@ impl Sas {
self.inner.lock().unwrap().verified_devices() self.inner.lock().unwrap().verified_devices()
} }
pub(crate) fn verified_identities(&self) -> Option<Arc<[UserIdentities]>> { pub(crate) fn verified_identities(&self) -> Option<Arc<[ReadOnlyUserIdentities]>> {
self.inner.lock().unwrap().verified_identities() self.inner.lock().unwrap().verified_identities()
} }
pub(crate) fn content_to_request(&self, content: AnyToDeviceEventContent) -> ToDeviceRequest { pub(crate) fn content_to_request(&self, content: AnyToDeviceEventContent) -> ToDeviceRequest {
content_to_request(self.other_user_id(), self.other_device_id().to_owned(), content) ToDeviceRequest::new(self.other_user_id(), self.other_device_id().to_owned(), content)
} }
} }
@ -463,23 +551,6 @@ impl AcceptSettings {
pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self { pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self {
Self { allowed_methods: methods } Self { allowed_methods: methods }
} }
fn apply(self, mut content: OwnedAcceptContent) -> OwnedAcceptContent {
match &mut content {
OwnedAcceptContent::ToDevice(AcceptToDeviceEventContent {
method: AcceptMethod::MSasV1(c),
..
})
| OwnedAcceptContent::Room(
_,
AcceptEventContent { method: AcceptMethod::MSasV1(c), .. },
) => {
c.short_authentication_string.retain(|sas| self.allowed_methods.contains(sas));
content
}
_ => content,
}
}
} }
#[cfg(test)] #[cfg(test)]
@ -536,6 +607,8 @@ mod test {
alice_store, alice_store,
None, None,
None, None,
true,
None,
); );
let flow_id = alice.flow_id().to_owned(); let flow_id = alice.flow_id().to_owned();
@ -549,6 +622,7 @@ mod test {
PrivateCrossSigningIdentity::empty(bob_id()), PrivateCrossSigningIdentity::empty(bob_id()),
alice_device, alice_device,
None, None,
None,
false, false,
) )
.unwrap(); .unwrap();

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use std::{ use std::{
convert::TryFrom, convert::{TryFrom, TryInto},
matches, matches,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
time::{Duration, Instant}, time::{Duration, Instant},
@ -52,7 +52,7 @@ use super::{
OutgoingContent, OutgoingContent,
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentities}, identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
verification::{ verification::{
event_enums::{ event_enums::{
AcceptContent, DoneContent, KeyContent, MacContent, OwnedAcceptContent, AcceptContent, DoneContent, KeyContent, MacContent, OwnedAcceptContent,
@ -102,7 +102,7 @@ impl TryFrom<AcceptV1Content> for AcceptedProtocols {
Err(CancelCode::UnknownMethod) Err(CancelCode::UnknownMethod)
} else { } else {
Ok(Self { Ok(Self {
method: VerificationMethod::MSasV1, method: VerificationMethod::SasV1,
hash: content.hash, hash: content.hash,
key_agreement_protocol: content.key_agreement_protocol, key_agreement_protocol: content.key_agreement_protocol,
message_auth_code: content.message_authentication_code, message_auth_code: content.message_authentication_code,
@ -149,7 +149,7 @@ impl TryFrom<&SasV1Content> for AcceptedProtocols {
} }
Ok(Self { Ok(Self {
method: VerificationMethod::MSasV1, method: VerificationMethod::SasV1,
hash: HashAlgorithm::Sha256, hash: HashAlgorithm::Sha256,
key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256, key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
message_auth_code: MessageAuthenticationCode::HkdfHmacSha256, message_auth_code: MessageAuthenticationCode::HkdfHmacSha256,
@ -163,7 +163,7 @@ impl TryFrom<&SasV1Content> for AcceptedProtocols {
impl Default for AcceptedProtocols { impl Default for AcceptedProtocols {
fn default() -> Self { fn default() -> Self {
AcceptedProtocols { AcceptedProtocols {
method: VerificationMethod::MSasV1, method: VerificationMethod::SasV1,
hash: HashAlgorithm::Sha256, hash: HashAlgorithm::Sha256,
key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256, key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
message_auth_code: MessageAuthenticationCode::HkdfHmacSha256, message_auth_code: MessageAuthenticationCode::HkdfHmacSha256,
@ -222,7 +222,7 @@ impl<S: Clone + std::fmt::Debug> std::fmt::Debug for SasState<S> {
/// The initial SAS state. /// The initial SAS state.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Created { pub struct Created {
protocol_definitions: SasV1ContentInit, protocol_definitions: SasV1Content,
} }
/// The initial SAS state if the other side started the SAS verification. /// The initial SAS state if the other side started the SAS verification.
@ -241,6 +241,15 @@ pub struct Accepted {
commitment: String, commitment: String,
} }
/// The SAS state we're going to be in after we accepted our
/// verification start event.
#[derive(Clone, Debug)]
pub struct WeAccepted {
we_started: bool,
pub accepted_protocols: Arc<AcceptedProtocols>,
commitment: String,
}
/// The SAS state we're going to be in after we received the public key of the /// The SAS state we're going to be in after we received the public key of the
/// other participant. /// other participant.
/// ///
@ -268,7 +277,7 @@ pub struct MacReceived {
we_started: bool, we_started: bool,
their_pubkey: String, their_pubkey: String,
verified_devices: Arc<[ReadOnlyDevice]>, verified_devices: Arc<[ReadOnlyDevice]>,
verified_master_keys: Arc<[UserIdentities]>, verified_master_keys: Arc<[ReadOnlyUserIdentities]>,
pub accepted_protocols: Arc<AcceptedProtocols>, pub accepted_protocols: Arc<AcceptedProtocols>,
} }
@ -278,7 +287,7 @@ pub struct MacReceived {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct WaitingForDone { pub struct WaitingForDone {
verified_devices: Arc<[ReadOnlyDevice]>, verified_devices: Arc<[ReadOnlyDevice]>,
verified_master_keys: Arc<[UserIdentities]>, verified_master_keys: Arc<[ReadOnlyUserIdentities]>,
} }
impl<S: Clone> SasState<S> { impl<S: Clone> SasState<S> {
@ -298,14 +307,14 @@ impl<S: Clone> SasState<S> {
self.ids.other_device.clone() self.ids.other_device.clone()
} }
pub fn cancel(self, cancel_code: CancelCode) -> SasState<Cancelled> { pub fn cancel(self, cancelled_by_us: bool, cancel_code: CancelCode) -> SasState<Cancelled> {
SasState { SasState {
inner: self.inner, inner: self.inner,
ids: self.ids, ids: self.ids,
creation_time: self.creation_time, creation_time: self.creation_time,
last_event_time: self.last_event_time, last_event_time: self.last_event_time,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
state: Arc::new(Cancelled::new(cancel_code)), state: Arc::new(Cancelled::new(cancelled_by_us, cancel_code)),
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
} }
} }
@ -353,7 +362,7 @@ impl SasState<Created> {
pub fn new( pub fn new(
account: ReadOnlyAccount, account: ReadOnlyAccount,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
transaction_id: Option<String>, transaction_id: Option<String>,
) -> SasState<Created> { ) -> SasState<Created> {
let started_from_request = transaction_id.is_some(); let started_from_request = transaction_id.is_some();
@ -379,7 +388,7 @@ impl SasState<Created> {
event_id: EventId, event_id: EventId,
account: ReadOnlyAccount, account: ReadOnlyAccount,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
) -> SasState<Created> { ) -> SasState<Created> {
let flow_id = FlowId::InRoom(room_id, event_id); let flow_id = FlowId::InRoom(room_id, event_id);
Self::new_helper(flow_id, account, other_device, other_identity, false) Self::new_helper(flow_id, account, other_device, other_identity, false)
@ -389,7 +398,7 @@ impl SasState<Created> {
flow_id: FlowId, flow_id: FlowId,
account: ReadOnlyAccount, account: ReadOnlyAccount,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
started_from_request: bool, started_from_request: bool,
) -> SasState<Created> { ) -> SasState<Created> {
SasState { SasState {
@ -407,7 +416,9 @@ impl SasState<Created> {
key_agreement_protocols: KEY_AGREEMENT_PROTOCOLS.to_vec(), key_agreement_protocols: KEY_AGREEMENT_PROTOCOLS.to_vec(),
message_authentication_codes: MACS.to_vec(), message_authentication_codes: MACS.to_vec(),
hashes: HASHES.to_vec(), hashes: HASHES.to_vec(),
}, }
.try_into()
.expect("Invalid protocol definition."),
}), }),
} }
} }
@ -417,19 +428,13 @@ impl SasState<Created> {
FlowId::ToDevice(s) => OwnedStartContent::ToDevice(StartToDeviceEventContent::new( FlowId::ToDevice(s) => OwnedStartContent::ToDevice(StartToDeviceEventContent::new(
self.device_id().into(), self.device_id().into(),
s.to_string(), s.to_string(),
StartMethod::SasV1( StartMethod::SasV1(self.state.protocol_definitions.clone()),
SasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."),
),
)), )),
FlowId::InRoom(r, e) => OwnedStartContent::Room( FlowId::InRoom(r, e) => OwnedStartContent::Room(
r.clone(), r.clone(),
StartEventContent::new( StartEventContent::new(
self.device_id().into(), self.device_id().into(),
StartMethod::SasV1( StartMethod::SasV1(self.state.protocol_definitions.clone()),
SasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."),
),
Relation::new(e.clone()), Relation::new(e.clone()),
), ),
), ),
@ -448,11 +453,11 @@ impl SasState<Created> {
sender: &UserId, sender: &UserId,
content: &AcceptContent, content: &AcceptContent,
) -> Result<SasState<Accepted>, SasState<Cancelled>> { ) -> Result<SasState<Accepted>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(c))?; self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
if let AcceptMethod::MSasV1(content) = content.method() { if let AcceptMethod::SasV1(content) = content.method() {
let accepted_protocols = let accepted_protocols = AcceptedProtocols::try_from(content.clone())
AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(true, c))?;
let start_content = self.as_content().into(); let start_content = self.as_content().into();
@ -461,7 +466,7 @@ impl SasState<Created> {
ids: self.ids, ids: self.ids,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time, creation_time: self.creation_time,
last_event_time: self.last_event_time, last_event_time: Instant::now().into(),
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
state: Arc::new(Accepted { state: Arc::new(Accepted {
start_content, start_content,
@ -470,7 +475,7 @@ impl SasState<Created> {
}), }),
}) })
} else { } else {
Err(self.cancel(CancelCode::UnknownMethod)) Err(self.cancel(true, CancelCode::UnknownMethod))
} }
} }
} }
@ -492,7 +497,7 @@ impl SasState<Started> {
pub fn from_start_event( pub fn from_start_event(
account: ReadOnlyAccount, account: ReadOnlyAccount,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<ReadOnlyUserIdentities>,
flow_id: FlowId, flow_id: FlowId,
content: &StartContent, content: &StartContent,
started_from_request: bool, started_from_request: bool,
@ -513,7 +518,7 @@ impl SasState<Started> {
}, },
verification_flow_id: flow_id.clone(), verification_flow_id: flow_id.clone(),
state: Arc::new(Cancelled::new(CancelCode::UnknownMethod)), state: Arc::new(Cancelled::new(true, CancelCode::UnknownMethod)),
}; };
if let StartMethod::SasV1(method_content) = content.method() { if let StartMethod::SasV1(method_content) = content.method() {
@ -552,6 +557,32 @@ impl SasState<Started> {
} }
} }
pub fn into_accepted(self, methods: Vec<ShortAuthenticationString>) -> SasState<WeAccepted> {
let mut accepted_protocols = self.state.accepted_protocols.as_ref().to_owned();
accepted_protocols.short_auth_string = methods;
// Decimal is required per spec.
if !accepted_protocols.short_auth_string.contains(&ShortAuthenticationString::Decimal) {
accepted_protocols.short_auth_string.push(ShortAuthenticationString::Decimal);
}
SasState {
inner: self.inner,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time,
last_event_time: self.last_event_time,
started_from_request: self.started_from_request,
state: Arc::new(WeAccepted {
we_started: false,
accepted_protocols: accepted_protocols.into(),
commitment: self.state.commitment.clone(),
}),
}
}
}
impl SasState<WeAccepted> {
/// Get the content for the accept event. /// Get the content for the accept event.
/// ///
/// The content needs to be sent to the other device. /// The content needs to be sent to the other device.
@ -560,7 +591,7 @@ impl SasState<Started> {
/// been started because of a /// been started because of a
/// m.key.verification.request -> m.key.verification.ready flow. /// m.key.verification.request -> m.key.verification.ready flow.
pub fn as_content(&self) -> OwnedAcceptContent { pub fn as_content(&self) -> OwnedAcceptContent {
let method = AcceptMethod::MSasV1( let method = AcceptMethod::SasV1(
AcceptV1ContentInit { AcceptV1ContentInit {
commitment: self.state.commitment.clone(), commitment: self.state.commitment.clone(),
hash: self.state.accepted_protocols.hash.clone(), hash: self.state.accepted_protocols.hash.clone(),
@ -604,7 +635,7 @@ impl SasState<Started> {
sender: &UserId, sender: &UserId,
content: &KeyContent, content: &KeyContent,
) -> Result<SasState<KeyReceived>, SasState<Cancelled>> { ) -> Result<SasState<KeyReceived>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(c))?; self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
let their_pubkey = content.public_key().to_owned(); let their_pubkey = content.public_key().to_owned();
@ -619,7 +650,7 @@ impl SasState<Started> {
ids: self.ids, ids: self.ids,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time, creation_time: self.creation_time,
last_event_time: self.last_event_time, last_event_time: Instant::now().into(),
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
state: Arc::new(KeyReceived { state: Arc::new(KeyReceived {
we_started: false, we_started: false,
@ -644,7 +675,7 @@ impl SasState<Accepted> {
sender: &UserId, sender: &UserId,
content: &KeyContent, content: &KeyContent,
) -> Result<SasState<KeyReceived>, SasState<Cancelled>> { ) -> Result<SasState<KeyReceived>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(c))?; self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
let commitment = calculate_commitment( let commitment = calculate_commitment(
content.public_key(), content.public_key(),
@ -652,7 +683,7 @@ impl SasState<Accepted> {
); );
if self.state.commitment != commitment { if self.state.commitment != commitment {
Err(self.cancel(CancelCode::InvalidMessage)) Err(self.cancel(true, CancelCode::InvalidMessage))
} else { } else {
let their_pubkey = content.public_key().to_owned(); let their_pubkey = content.public_key().to_owned();
@ -667,7 +698,7 @@ impl SasState<Accepted> {
ids: self.ids, ids: self.ids,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time, creation_time: self.creation_time,
last_event_time: self.last_event_time, last_event_time: Instant::now().into(),
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
state: Arc::new(KeyReceived { state: Arc::new(KeyReceived {
their_pubkey, their_pubkey,
@ -684,10 +715,10 @@ impl SasState<Accepted> {
pub fn as_content(&self) -> OutgoingContent { pub fn as_content(&self) -> OutgoingContent {
match &*self.verification_flow_id { match &*self.verification_flow_id {
FlowId::ToDevice(s) => { FlowId::ToDevice(s) => {
AnyToDeviceEventContent::KeyVerificationKey(KeyToDeviceEventContent { AnyToDeviceEventContent::KeyVerificationKey(KeyToDeviceEventContent::new(
transaction_id: s.to_string(), s.to_string(),
key: self.inner.lock().unwrap().public_key(), self.inner.lock().unwrap().public_key(),
}) ))
.into() .into()
} }
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => (
@ -710,10 +741,10 @@ impl SasState<KeyReceived> {
pub fn as_content(&self) -> OutgoingContent { pub fn as_content(&self) -> OutgoingContent {
match &*self.verification_flow_id { match &*self.verification_flow_id {
FlowId::ToDevice(s) => { FlowId::ToDevice(s) => {
AnyToDeviceEventContent::KeyVerificationKey(KeyToDeviceEventContent { AnyToDeviceEventContent::KeyVerificationKey(KeyToDeviceEventContent::new(
transaction_id: s.to_string(), s.to_string(),
key: self.inner.lock().unwrap().public_key(), self.inner.lock().unwrap().public_key(),
}) ))
.into() .into()
} }
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => (
@ -781,7 +812,7 @@ impl SasState<KeyReceived> {
sender: &UserId, sender: &UserId,
content: &MacContent, content: &MacContent,
) -> Result<SasState<MacReceived>, SasState<Cancelled>> { ) -> Result<SasState<MacReceived>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(c))?; self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
let (devices, master_keys) = receive_mac_event( let (devices, master_keys) = receive_mac_event(
&self.inner.lock().unwrap(), &self.inner.lock().unwrap(),
@ -790,13 +821,13 @@ impl SasState<KeyReceived> {
sender, sender,
content, content,
) )
.map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(true, c))?;
Ok(SasState { Ok(SasState {
inner: self.inner, inner: self.inner,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time, creation_time: self.creation_time,
last_event_time: self.last_event_time, last_event_time: Instant::now().into(),
ids: self.ids, ids: self.ids,
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
state: Arc::new(MacReceived { state: Arc::new(MacReceived {
@ -841,7 +872,7 @@ impl SasState<Confirmed> {
sender: &UserId, sender: &UserId,
content: &MacContent, content: &MacContent,
) -> Result<SasState<Done>, SasState<Cancelled>> { ) -> Result<SasState<Done>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(c))?; self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
let (devices, master_keys) = receive_mac_event( let (devices, master_keys) = receive_mac_event(
&self.inner.lock().unwrap(), &self.inner.lock().unwrap(),
@ -850,12 +881,12 @@ impl SasState<Confirmed> {
sender, sender,
content, content,
) )
.map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(true, c))?;
Ok(SasState { Ok(SasState {
inner: self.inner, inner: self.inner,
creation_time: self.creation_time, creation_time: self.creation_time,
last_event_time: self.last_event_time, last_event_time: Instant::now().into(),
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
ids: self.ids, ids: self.ids,
@ -881,7 +912,7 @@ impl SasState<Confirmed> {
sender: &UserId, sender: &UserId,
content: &MacContent, content: &MacContent,
) -> Result<SasState<WaitingForDone>, SasState<Cancelled>> { ) -> Result<SasState<WaitingForDone>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(c))?; self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
let (devices, master_keys) = receive_mac_event( let (devices, master_keys) = receive_mac_event(
&self.inner.lock().unwrap(), &self.inner.lock().unwrap(),
@ -890,12 +921,12 @@ impl SasState<Confirmed> {
sender, sender,
content, content,
) )
.map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(true, c))?;
Ok(SasState { Ok(SasState {
inner: self.inner, inner: self.inner,
creation_time: self.creation_time, creation_time: self.creation_time,
last_event_time: self.last_event_time, last_event_time: Instant::now().into(),
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
ids: self.ids, ids: self.ids,
@ -1036,12 +1067,12 @@ impl SasState<WaitingForDone> {
sender: &UserId, sender: &UserId,
content: &DoneContent, content: &DoneContent,
) -> Result<SasState<Done>, SasState<Cancelled>> { ) -> Result<SasState<Done>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(c))?; self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
Ok(SasState { Ok(SasState {
inner: self.inner, inner: self.inner,
creation_time: self.creation_time, creation_time: self.creation_time,
last_event_time: self.last_event_time, last_event_time: Instant::now().into(),
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
ids: self.ids, ids: self.ids,
@ -1073,7 +1104,7 @@ impl SasState<Done> {
} }
/// Get the list of verified identities. /// Get the list of verified identities.
pub fn verified_identities(&self) -> Arc<[UserIdentities]> { pub fn verified_identities(&self) -> Arc<[ReadOnlyUserIdentities]> {
self.state.verified_master_keys.clone() self.state.verified_master_keys.clone()
} }
} }
@ -1090,13 +1121,15 @@ mod test {
use ruma::{ use ruma::{
events::key::verification::{ events::key::verification::{
accept::{AcceptMethod, CustomContent}, accept::{AcceptMethod, AcceptToDeviceEventContent},
start::{CustomContent as CustomStartContent, StartMethod}, start::{StartMethod, StartToDeviceEventContent},
ShortAuthenticationString,
}, },
DeviceId, UserId, DeviceId, UserId,
}; };
use serde_json::json;
use super::{Accepted, Created, SasState, Started}; use super::{Accepted, Created, SasState, Started, WeAccepted};
use crate::{ use crate::{
verification::event_enums::{AcceptContent, KeyContent, MacContent, StartContent}, verification::event_enums::{AcceptContent, KeyContent, MacContent, StartContent},
ReadOnlyAccount, ReadOnlyDevice, ReadOnlyAccount, ReadOnlyDevice,
@ -1118,7 +1151,7 @@ mod test {
"BOBDEVCIE".into() "BOBDEVCIE".into()
} }
async fn get_sas_pair() -> (SasState<Created>, SasState<Started>) { async fn get_sas_pair() -> (SasState<Created>, SasState<WeAccepted>) {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let alice_device = ReadOnlyDevice::from_account(&alice).await; let alice_device = ReadOnlyDevice::from_account(&alice).await;
@ -1138,8 +1171,9 @@ mod test {
&start_content.as_start_content(), &start_content.as_start_content(),
false, false,
); );
let bob_sas = bob_sas.unwrap().into_accepted(vec![ShortAuthenticationString::Emoji]);
(alice_sas, bob_sas.unwrap()) (alice_sas, bob_sas)
} }
#[tokio::test] #[tokio::test]
@ -1227,7 +1261,7 @@ mod test {
let mut method = content.method_mut(); let mut method = content.method_mut();
match &mut method { match &mut method {
AcceptMethod::MSasV1(ref mut c) => { AcceptMethod::SasV1(ref mut c) => {
c.commitment = "".to_string(); c.commitment = "".to_string();
} }
_ => panic!("Unknown accept event content"), _ => panic!("Unknown accept event content"),
@ -1266,7 +1300,7 @@ mod test {
let mut method = content.method_mut(); let mut method = content.method_mut();
match &mut method { match &mut method {
AcceptMethod::MSasV1(ref mut c) => { AcceptMethod::SasV1(ref mut c) => {
c.short_authentication_string = vec![]; c.short_authentication_string = vec![];
} }
_ => panic!("Unknown accept event content"), _ => panic!("Unknown accept event content"),
@ -1283,14 +1317,13 @@ mod test {
async fn sas_unknown_method() { async fn sas_unknown_method() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let mut content = bob.as_content(); let content = json!({
let method = content.method_mut(); "method": "m.sas.custom",
"method_data": "something",
*method = AcceptMethod::Custom(CustomContent { "transaction_id": "some_id",
method: "m.sas.custom".to_string(),
data: Default::default(),
}); });
let content: AcceptToDeviceEventContent = serde_json::from_value(content).unwrap();
let content = AcceptContent::from(&content); let content = AcceptContent::from(&content);
alice alice
@ -1331,22 +1364,22 @@ mod test {
) )
.expect_err("Didn't cancel on invalid MAC method"); .expect_err("Didn't cancel on invalid MAC method");
let mut start_content = alice_sas.as_content(); let content = json!({
let method = start_content.method_mut(); "method": "m.sas.custom",
"from_device": "DEVICEID",
*method = StartMethod::Custom(CustomStartContent { "method_data": "something",
method: "m.sas.custom".to_string(), "transaction_id": "some_id",
data: Default::default(),
}); });
let flow_id = start_content.flow_id(); let content: StartToDeviceEventContent = serde_json::from_value(content).unwrap();
let content = StartContent::from(&start_content); let content = StartContent::from(&content);
let flow_id = content.flow_id().to_owned();
SasState::<Started>::from_start_event( SasState::<Started>::from_start_event(
bob.clone(), bob.clone(),
alice_device, alice_device,
None, None,
flow_id, flow_id.into(),
&content, &content,
false, false,
) )

View File

@ -8,16 +8,16 @@ license = "Apache-2.0"
name = "matrix-sdk-test" name = "matrix-sdk-test"
readme = "README.md" readme = "README.md"
repository = "https://github.com/matrix-org/matrix-rust-sdk" repository = "https://github.com/matrix-org/matrix-rust-sdk"
version = "0.2.0" version = "0.3.0"
[features] [features]
appservice = [] appservice = []
[dependencies] [dependencies]
http = "0.2.3" http = "0.2.4"
lazy_static = "1.4.0" lazy_static = "1.4.0"
matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" } matrix-sdk-common = { version = "0.3.0", path = "../matrix_sdk_common" }
matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" } matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" }
ruma = { version = "0.1.2", features = ["client-api-c"] } ruma = { version = "0.2.0", features = ["client-api-c"] }
serde = "1.0.122" serde = "1.0.126"
serde_json = "1.0.61" serde_json = "1.0.64"

View File

@ -1,6 +1,6 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use ruma::{events::AnyRoomEvent, identifiers::room_id}; use ruma::{events::AnyRoomEvent, room_id};
use serde_json::Value; use serde_json::Value;
use crate::{test_json, EventsJson}; use crate::{test_json, EventsJson};