Add use_small_heuristics option and run fmt

This commit is contained in:
Devin Ragotzy 2021-05-12 11:00:47 -04:00
parent c85f4d4f0c
commit 2ef0c2959c
69 changed files with 1563 additions and 3660 deletions

View file

@ -72,15 +72,11 @@ async fn login_and_sync(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client
.login(username, password, None, Some("autojoin bot"))
.await?;
client.login(username, password, None, Some("autojoin bot")).await?;
println!("logged in as {}", username);
client
.set_event_handler(Box::new(AutoJoinBot::new(client.clone())))
.await;
client.set_event_handler(Box::new(AutoJoinBot::new(client.clone()))).await;
client.sync(SyncSettings::default()).await;

View file

@ -69,9 +69,7 @@ async fn login_and_sync(
// create a new Client with the given homeserver url and config
let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client
.login(&username, &password, None, Some("command bot"))
.await?;
client.login(&username, &password, None, Some("command bot")).await?;
println!("logged in as {}", username);

View file

@ -21,11 +21,7 @@ fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> Aut
auth_parameters.insert("identifier".to_owned(), identifier);
auth_parameters.insert("password".to_owned(), password.to_owned().into());
AuthData::DirectRequest {
kind: "m.login.password",
auth_parameters,
session,
}
AuthData::DirectRequest { kind: "m.login.password", auth_parameters, session }
}
async fn bootstrap(client: Client, user_id: UserId, password: String) {
@ -33,9 +29,7 @@ async fn bootstrap(client: Client, user_id: UserId, password: String) {
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.expect("error: unable to read user input");
io::stdin().read_line(&mut input).expect("error: unable to read user input");
#[cfg(feature = "encryption")]
if let Err(e) = client.bootstrap_cross_signing(None).await {
@ -62,9 +56,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap();
let response = client
.login(username, password, None, Some("rust-sdk"))
.await?;
let response = client.login(username, password, None, Some("rust-sdk")).await?;
let user_id = &response.user_id;
let client_ref = &client;

View file

@ -19,9 +19,7 @@ async fn wait_for_confirmation(client: Client, sas: Sas) {
println!("Does the emoji match: {:?}", sas.emoji());
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.expect("error: unable to read user input");
io::stdin().read_line(&mut input).expect("error: unable to read user input");
match input.trim().to_lowercase().as_ref() {
"yes" | "true" | "ok" => {
@ -68,9 +66,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap();
client
.login(username, password, None, Some("rust-sdk"))
.await?;
client.login(username, password, None, Some("rust-sdk")).await?;
let client_ref = &client;
let initial_sync = Arc::new(AtomicBool::from(true));
@ -81,12 +77,7 @@ async fn login(
let client = &client_ref;
let initial = &initial_ref;
for event in response
.to_device
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in response.to_device.events.iter().filter_map(|e| e.deserialize().ok()) {
match event {
AnyToDeviceEvent::KeyVerificationStart(e) => {
let sas = client
@ -129,11 +120,8 @@ async fn login(
if !initial.load(Ordering::SeqCst) {
for (_room_id, room_info) in response.rooms.join {
for event in room_info
.timeline
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
for event in
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
{
if let AnySyncRoomEvent::Message(event) = event {
match event {

View file

@ -28,10 +28,7 @@ async fn get_profile(client: Client, mxid: &UserId) -> MatrixResult<UserProfile>
// Use the response and construct a UserProfile struct.
// See https://docs.rs/ruma-client-api/0.9.0/ruma_client_api/r0/profile/get_profile/struct.Response.html
// for details on the Response for this Request
let user_profile = UserProfile {
avatar_url: resp.avatar_url,
displayname: resp.displayname,
};
let user_profile = UserProfile { avatar_url: resp.avatar_url, displayname: resp.displayname };
Ok(user_profile)
}
@ -43,9 +40,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap();
client
.login(username, password, None, Some("rust-sdk"))
.await?;
client.login(username, password, None, Some("rust-sdk")).await?;
Ok(client)
}

View file

@ -52,9 +52,7 @@ impl EventHandler for ImageBot {
println!("sending image");
let mut image = self.image.lock().await;
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None)
.await
.unwrap();
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap();
image.seek(SeekFrom::Start(0)).unwrap();
@ -73,14 +71,10 @@ async fn login_and_sync(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap();
client
.login(&username, &password, None, Some("command bot"))
.await?;
client.login(&username, &password, None, Some("command bot")).await?;
client.sync_once(SyncSettings::default()).await.unwrap();
client
.set_event_handler(Box::new(ImageBot::new(image)))
.await;
client.set_event_handler(Box::new(ImageBot::new(image))).await;
let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
client.sync(settings).await;
@ -91,26 +85,19 @@ async fn login_and_sync(
#[tokio::main]
async fn main() -> Result<(), matrix_sdk::Error> {
tracing_subscriber::fmt::init();
let (homeserver_url, username, password, image_path) = match (
env::args().nth(1),
env::args().nth(2),
env::args().nth(3),
env::args().nth(4),
) {
(Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d),
_ => {
eprintln!(
"Usage: {} <homeserver_url> <username> <password> <image>",
env::args().next().unwrap()
);
exit(1)
}
};
let (homeserver_url, username, password, image_path) =
match (env::args().nth(1), env::args().nth(2), env::args().nth(3), env::args().nth(4)) {
(Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d),
_ => {
eprintln!(
"Usage: {} <homeserver_url> <username> <password> <image>",
env::args().next().unwrap()
);
exit(1)
}
};
println!(
"helloooo {} {} {} {:#?}",
homeserver_url, username, password, image_path
);
println!("helloooo {} {} {} {:#?}", homeserver_url, username, password, image_path);
let path = PathBuf::from(image_path);
let image = File::open(path).expect("Can't open image file.");

View file

@ -28,9 +28,7 @@ impl EventHandler for EventCallback {
} = event
{
let member = room.get_member(&sender).await.unwrap().unwrap();
let name = member
.display_name()
.unwrap_or_else(|| member.user_id().as_str());
let name = member.display_name().unwrap_or_else(|| member.user_id().as_str());
println!("{}: {}", name, msg_body);
}
}
@ -47,9 +45,7 @@ async fn login(
client.set_event_handler(Box::new(EventCallback)).await;
client
.login(username, password, None, Some("rust-sdk"))
.await?;
client.login(username, password, None, Some("rust-sdk")).await?;
client.sync(SyncSettings::new()).await;
Ok(())

File diff suppressed because it is too large Load diff

View file

@ -65,10 +65,7 @@ impl Device {
let (sas, request) = self.inner.start_verification().await?;
self.client.send_to_device(&request).await?;
Ok(Sas {
inner: sas,
client: self.client.clone(),
})
Ok(Sas { inner: sas, client: self.client.clone() })
}
/// Is the device trusted.
@ -102,10 +99,7 @@ pub struct UserDevices {
impl UserDevices {
/// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
self.inner.get(device_id).map(|d| Device {
inner: d,
client: self.client.clone(),
})
self.inner.get(device_id).map(|d| Device { inner: d, client: self.client.clone() })
}
/// Iterator over all the device ids of the user devices.
@ -117,9 +111,6 @@ impl UserDevices {
pub fn devices(&self) -> impl Iterator<Item = Device> + '_ {
let client = self.client.clone();
self.inner.devices().map(move |d| Device {
inner: d,
client: client.clone(),
})
self.inner.devices().map(move |d| Device { inner: d, client: client.clone() })
}
}

View file

@ -43,11 +43,13 @@ pub enum HttpError {
#[error(transparent)]
Reqwest(#[from] ReqwestError),
/// Queried endpoint requires authentication but was called on an anonymous client.
/// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired,
/// Client tried to force authentication but did not provide an access token.
/// Client tried to force authentication but did not provide an access
/// token.
#[error("tried to force authentication but no access token was provided")]
ForcedAuthenticationWithoutAccessToken,
@ -69,9 +71,10 @@ pub enum HttpError {
/// An error occurred while authenticating.
///
/// When registering or authenticating the Matrix server can send a `UiaaResponse`
/// as the error type, this is a User-Interactive Authentication API response. This
/// represents an error with information about how to authenticate the user.
/// When registering or authenticating the Matrix server can send a
/// `UiaaResponse` as the error type, this is a User-Interactive
/// Authentication API response. This represents an error with
/// information about how to authenticate the user.
#[error(transparent)]
UiaaError(#[from] FromHttpResponseError<UiaaError>),
@ -96,7 +99,8 @@ pub enum Error {
#[error(transparent)]
Http(#[from] HttpError),
/// Queried endpoint requires authentication but was called on an anonymous client.
/// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired,

View file

@ -87,39 +87,24 @@ impl Handler {
for (room_id, room_info) in &response.rooms.join {
if let Some(room) = self.get_room(room_id) {
for event in room_info
.ephemeral
.events
.iter()
.filter_map(|e| e.deserialize().ok())
for event in room_info.ephemeral.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_ephemeral_event(room.clone(), &event).await;
}
for event in room_info
.account_data
.events
.iter()
.filter_map(|e| e.deserialize().ok())
for event in
room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_room_account_data_event(room.clone(), &event)
.await;
}
for event in room_info
.state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
self.handle_state_event(room.clone(), &event).await;
}
for event in room_info
.timeline
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
for event in
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
{
self.handle_timeline_event(room.clone(), &event).await;
}
@ -128,30 +113,19 @@ impl Handler {
for (room_id, room_info) in &response.rooms.leave {
if let Some(room) = self.get_room(room_id) {
for event in room_info
.account_data
.events
.iter()
.filter_map(|e| e.deserialize().ok())
for event in
room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_room_account_data_event(room.clone(), &event)
.await;
}
for event in room_info
.state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
self.handle_state_event(room.clone(), &event).await;
}
for event in room_info
.timeline
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
for event in
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
{
self.handle_timeline_event(room.clone(), &event).await;
}
@ -160,31 +134,22 @@ impl Handler {
for (room_id, room_info) in &response.rooms.invite {
if let Some(room) = self.get_room(room_id) {
for event in room_info
.invite_state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
for event in
room_info.invite_state.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_stripped_state_event(room.clone(), &event).await;
}
}
}
for event in response
.presence
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in response.presence.events.iter().filter_map(|e| e.deserialize().ok()) {
self.on_presence_event(&event).await;
}
for (room_id, notifications) in &response.notifications {
if let Some(room) = self.get_room(&room_id) {
for notification in notifications {
self.on_room_notification(room.clone(), notification.clone())
.await;
self.on_room_notification(room.clone(), notification.clone()).await;
}
}
}
@ -248,8 +213,7 @@ impl Handler {
self.on_room_tombstone(room, &tomb).await
}
AnySyncStateEvent::Custom(custom) => {
self.on_custom_event(room, &CustomEvent::State(custom))
.await
self.on_custom_event(room, &CustomEvent::State(custom)).await
}
_ => {}
}
@ -267,8 +231,7 @@ impl Handler {
}
AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, &name).await,
AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => {
self.on_stripped_state_canonical_alias(room, &canonical)
.await
self.on_stripped_state_canonical_alias(room, &canonical).await
}
AnyStrippedStateEvent::RoomAliases(aliases) => {
self.on_stripped_state_aliases(room, &aliases).await
@ -340,8 +303,9 @@ pub enum CustomEvent<'c> {
StrippedState(&'c StrippedStateEvent<CustomEventContent>),
}
/// This trait allows any type implementing `EventHandler` to specify event callbacks for each
/// event. The `Client` calls each method when the corresponding event is received.
/// This trait allows any type implementing `EventHandler` to specify event
/// callbacks for each event. The `Client` calls each method when the
/// corresponding event is received.
///
/// # Examples
/// ```
@ -426,8 +390,8 @@ pub trait EventHandler: Send + Sync {
/// Fires when `Client` receives a `RoomEvent::Tombstone` event.
async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {}
/// Fires when `Client` receives room events that trigger notifications according to
/// the push rules of the user.
/// Fires when `Client` receives room events that trigger notifications
/// according to the push rules of the user.
async fn on_room_notification(&self, _: Room, _: Notification) {}
// `RoomEvent`s from `IncomingState`
@ -452,7 +416,8 @@ pub trait EventHandler: Send + Sync {
async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {}
// `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member(
&self,
_: Room,
@ -460,32 +425,38 @@ pub trait EventHandler: Send + Sync {
_: Option<MemberEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event.
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName`
/// event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias(
&self,
_: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases(
&self,
_: Room,
_: &StrippedStateEvent<AliasesEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels(
&self,
_: Room,
_: &StrippedStateEvent<PowerLevelsEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules(
&self,
_: Room,
@ -522,17 +493,18 @@ pub trait EventHandler: Send + Sync {
/// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event.
async fn on_presence_event(&self, _: &PresenceEvent) {}
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails
/// because the event was unknown to ruma.
/// Fires when `Client` receives a `Event::Custom` event or if
/// deserialization fails because the event was unknown to ruma.
///
/// The only guarantee this method can give about the event is that it is valid JSON.
/// The only guarantee this method can give about the event is that it is
/// valid JSON.
async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {}
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails
/// because the event was unknown to ruma.
/// Fires when `Client` receives a `Event::Custom` event or if
/// deserialization fails because the event was unknown to ruma.
///
/// The only guarantee this method can give about the event is that it is in the
/// shape of a valid matrix event.
/// The only guarantee this method can give about the event is that it is in
/// the shape of a valid matrix event.
async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {}
}
@ -640,57 +612,50 @@ mod test {
}
// `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member(
&self,
_: Room,
_: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>,
) {
self.0
.lock()
.await
.push("stripped state member".to_string())
self.0.lock().await.push("stripped state member".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomName` event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {
self.0.lock().await.push("stripped state name".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias`
/// event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias(
&self,
_: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>,
) {
self.0
.lock()
.await
.push("stripped state canonical".to_string())
self.0.lock().await.push("stripped state canonical".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases(
&self,
_: Room,
_: &StrippedStateEvent<AliasesEventContent>,
) {
self.0
.lock()
.await
.push("stripped state aliases".to_string())
self.0.lock().await.push("stripped state aliases".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar(
&self,
_: Room,
_: &StrippedStateEvent<AvatarEventContent>,
) {
self.0
.lock()
.await
.push("stripped state avatar".to_string())
self.0.lock().await.push("stripped state avatar".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels(
&self,
_: Room,
@ -698,7 +663,8 @@ mod test {
) {
self.0.lock().await.push("stripped state power".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules(
&self,
_: Room,
@ -769,14 +735,11 @@ mod test {
}
async fn mock_sync(client: &Client, response: String) {
let _m = mock(
"GET",
Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()),
)
.with_status(200)
.match_header("authorization", "Bearer 1234")
.with_body(response)
.create();
let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()))
.with_status(200)
.match_header("authorization", "Bearer 1234")
.with_body(response)
.create();
let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000));
let _response = client.sync_once(sync_settings).await.unwrap();
@ -824,14 +787,7 @@ mod test {
mock_sync(&client, test_json::INVITE_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(
v.as_slice(),
[
"stripped state name",
"stripped state member",
"presence event"
],
)
assert_eq!(v.as_slice(), ["stripped state name", "stripped state member", "presence event"],)
}
#[async_test]
@ -898,15 +854,7 @@ mod test {
mock_sync(&client, test_json::VOIP_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(
v.as_slice(),
[
"call invite",
"call answer",
"call candidates",
"call hangup",
],
)
assert_eq!(v.as_slice(), ["call invite", "call answer", "call candidates", "call hangup",],)
}
#[async_test]

View file

@ -121,8 +121,7 @@ impl HttpClient {
let request = if !self.request_config.assert_identity {
self.try_into_http_request(request, session, config).await?
} else {
self.try_into_http_request_with_identy_assertion(request, session, config)
.await?
self.try_into_http_request_with_identy_assertion(request, session, config).await?
};
self.inner.send_request(request, config).await
@ -201,9 +200,7 @@ impl HttpClient {
request: create_content::Request<'_>,
config: Option<RequestConfig>,
) -> Result<create_content::Response, HttpError> {
let response = self
.send_request(request, self.session.clone(), config)
.await?;
let response = self.send_request(request, self.session.clone(), config).await?;
Ok(create_content::Response::try_from_http_response(response)?)
}
@ -216,9 +213,7 @@ impl HttpClient {
Request: OutgoingRequest + Debug,
HttpError: From<FromHttpResponseError<Request::EndpointError>>,
{
let response = self
.send_request(request, self.session.clone(), config)
.await?;
let response = self.send_request(request, self.session.clone(), config).await?;
trace!("Got response: {:?}", response);
@ -255,9 +250,7 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result<Client, HttpEr
headers.insert(reqwest::header::USER_AGENT, user_agent);
http_client
.default_headers(headers)
.timeout(config.request_config.timeout)
http_client.default_headers(headers).timeout(config.request_config.timeout)
};
#[cfg(target_arch = "wasm32")]
@ -273,9 +266,7 @@ async fn response_to_http_response(
let status = response.status();
let mut http_builder = HttpResponse::builder().status(status);
let headers = http_builder
.headers_mut()
.expect("Can't get the response builder headers");
let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
for (k, v) in response.headers_mut().drain() {
if let Some(key) = k {
@ -285,9 +276,7 @@ async fn response_to_http_response(
let body = response.bytes().await?;
Ok(http_builder
.body(body)
.expect("Can't construct a response using the given body"))
Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
}
#[cfg(any(target_arch = "wasm32"))]
@ -328,18 +317,12 @@ async fn send_request(
};
// Turn errors into permanent errors when the retry limit is reached
let error_type = if stop {
RetryError::Permanent
} else {
RetryError::Transient
};
let error_type = if stop { RetryError::Permanent } else { RetryError::Transient };
let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?;
let response = client
.execute(request)
.await
.map_err(|e| error_type(HttpError::Reqwest(e)))?;
let response =
client.execute(request).await.map_err(|e| error_type(HttpError::Reqwest(e)))?;
let status_code = response.status();
// TODO TOO_MANY_REQUESTS will have a retry timeout which we should

View file

@ -36,10 +36,7 @@ impl Common {
/// * `room` - The underlaying room.
pub fn new(client: Client, room: BaseRoom) -> Self {
// TODO: Make this private
Self {
inner: room,
client,
}
Self { inner: room, client }
}
/// Leave this room.
@ -152,35 +149,25 @@ impl Common {
pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> {
#[allow(clippy::map_clone)]
if let Some(mutex) = self
.client
.members_request_locks
.get(self.inner.room_id())
.map(|m| m.clone())
if let Some(mutex) =
self.client.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
{
mutex.lock().await;
Ok(None)
} else {
let mutex = Arc::new(Mutex::new(()));
self.client
.members_request_locks
.insert(self.inner.room_id().clone(), mutex.clone());
self.client.members_request_locks.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await;
let request = get_member_events::Request::new(self.inner.room_id());
let response = self.client.send(request, None).await?;
let response = self
.client
.base_client
.receive_members(self.inner.room_id(), &response)
.await?;
let response =
self.client.base_client.receive_members(self.inner.room_id(), &response).await?;
self.client
.members_request_locks
.remove(self.inner.room_id());
self.client.members_request_locks.remove(self.inner.room_id());
Ok(Some(response))
}

View file

@ -21,9 +21,7 @@ impl Invited {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Invited {
Some(Self {
inner: Common::new(client, room),
})
Some(Self { inner: Common::new(client, room) })
} else {
None
}

View file

@ -47,8 +47,9 @@ const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3);
/// A room in the joined state.
///
/// The `JoinedRoom` contains all methodes specific to a `Room` with type `RoomType::Joined`.
/// Operations may fail once the underlaying `Room` changes `RoomType`.
/// The `JoinedRoom` contains all methodes specific to a `Room` with type
/// `RoomType::Joined`. Operations may fail once the underlaying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)]
pub struct Joined {
pub(crate) inner: Common,
@ -63,7 +64,8 @@ impl Deref for Joined {
}
impl Joined {
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type `RoomType::Joined`.
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type
/// `RoomType::Joined`.
///
/// # Arguments
/// * `client` - The client used to make requests.
@ -72,9 +74,7 @@ impl Joined {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Joined {
Some(Self {
inner: Common::new(client, room),
})
Some(Self { inner: Common::new(client, room) })
} else {
None
}
@ -93,9 +93,7 @@ impl Joined {
///
/// * `reason` - The reason for banning this user.
pub async fn ban_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), {
reason
});
let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), { reason });
self.client.send(request, None).await?;
Ok(())
}
@ -104,13 +102,12 @@ impl Joined {
///
/// # Arguments
///
/// * `user_id` - The `UserId` of the user that should be kicked out of the room.
/// * `user_id` - The `UserId` of the user that should be kicked out of the
/// room.
///
/// * `reason` - Optional reason why the room member is being kicked out.
pub async fn kick_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), {
reason
});
let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), { reason });
self.client.send(request, None).await?;
Ok(())
}
@ -144,10 +141,11 @@ impl Joined {
/// Activate typing notice for this room.
///
/// The typing notice remains active for 4s. It can be deactivate at any point by setting
/// typing to `false`. If this method is called while the typing notice is active nothing will
/// happen. This method can be called on every key stroke, since it will do nothing while
/// typing is active.
/// The typing notice remains active for 4s. It can be deactivate at any
/// point by setting typing to `false`. If this method is called while
/// the typing notice is active nothing will happen. This method can be
/// called on every key stroke, since it will do nothing while typing is
/// active.
///
/// # Arguments
///
@ -179,21 +177,23 @@ impl Joined {
/// # });
/// ```
pub async fn typing_notice(&self, typing: bool) -> Result<()> {
// Only send a request to the homeserver if the old timeout has elapsed or the typing
// notice changed state within the TYPING_NOTICE_TIMEOUT
// Only send a request to the homeserver if the old timeout has elapsed
// or the typing notice changed state within the
// TYPING_NOTICE_TIMEOUT
let send =
if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) {
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
// We always reactivate the typing notice if typing is true or we may need to
// deactivate it if it's currently active if typing is false
// We always reactivate the typing notice if typing is true or
// we may need to deactivate it if it's
// currently active if typing is false
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
} else {
// Only send a request when we need to deactivate typing
!typing
}
} else {
// Typing notice is currently deactivated, therefore, send a request only when it's
// about to be activated
// Typing notice is currently deactivated, therefore, send a request
// only when it's about to be activated
typing
};
@ -216,11 +216,13 @@ impl Joined {
Ok(())
}
/// Send a request to notify this room that the user has read specific event.
/// Send a request to notify this room that the user has read specific
/// event.
///
/// # Arguments
///
/// * `event_id` - The `EventId` specifies the event to set the read receipt on.
/// * `event_id` - The `EventId` specifies the event to set the read receipt
/// on.
pub async fn read_receipt(&self, event_id: &EventId) -> Result<()> {
let request =
create_receipt::Request::new(self.inner.room_id(), ReceiptType::Read, event_id);
@ -229,22 +231,23 @@ impl Joined {
Ok(())
}
/// Send a request to notify this room that the user has read up to specific event.
/// Send a request to notify this room that the user has read up to specific
/// event.
///
/// # Arguments
///
/// * fully_read - The `EventId` of the event the user has read to.
///
/// * read_receipt - An `EventId` to specify the event to set the read receipt on.
/// * read_receipt - An `EventId` to specify the event to set the read
/// receipt on.
pub async fn read_marker(
&self,
fully_read: &EventId,
read_receipt: Option<&EventId>,
) -> Result<()> {
let request = assign!(
set_read_marker::Request::new(self.inner.room_id(), fully_read),
{ read_receipt }
);
let request = assign!(set_read_marker::Request::new(self.inner.room_id(), fully_read), {
read_receipt
});
self.client.send(request, None).await?;
Ok(())
@ -262,11 +265,8 @@ impl Joined {
// TODO expose this publicly so people can pre-share a group session if
// e.g. a user starts to type a message for a room.
#[allow(clippy::map_clone)]
if let Some(mutex) = self
.client
.group_session_locks
.get(self.inner.room_id())
.map(|m| m.clone())
if let Some(mutex) =
self.client.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
{
// If a group session share request is already going on,
// await the release of the lock.
@ -275,23 +275,14 @@ impl Joined {
// Otherwise create a new lock and share the group
// session.
let mutex = Arc::new(Mutex::new(()));
self.client
.group_session_locks
.insert(self.inner.room_id().clone(), mutex.clone());
self.client.group_session_locks.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await;
{
let joined = self
.client
.store()
.get_joined_user_ids(self.inner.room_id())
.await?;
let invited = self
.client
.store()
.get_invited_user_ids(self.inner.room_id())
.await?;
let joined = self.client.store().get_joined_user_ids(self.inner.room_id()).await?;
let invited =
self.client.store().get_invited_user_ids(self.inner.room_id()).await?;
let members = joined.iter().chain(&invited);
self.client.claim_one_time_keys(members).await?;
};
@ -304,10 +295,7 @@ impl Joined {
// session as using it would end up in undecryptable
// messages.
if let Err(r) = response {
self.client
.base_client
.invalidate_group_session(self.inner.room_id())
.await?;
self.client.base_client.invalidate_group_session(self.inner.room_id()).await?;
return Err(r);
}
}
@ -324,19 +312,13 @@ impl Joined {
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[instrument]
async fn share_group_session(&self) -> Result<()> {
let mut requests = self
.client
.base_client
.share_group_session(self.inner.room_id())
.await?;
let mut requests =
self.client.base_client.share_group_session(self.inner.room_id()).await?;
for request in requests.drain(..) {
let response = self.client.send_to_device(&request).await?;
self.client
.base_client
.mark_request_as_sent(&request.txn_id, &response)
.await?;
self.client.base_client.mark_request_as_sent(&request.txn_id, &response).await?;
}
Ok(())
@ -403,10 +385,7 @@ impl Joined {
self.preshare_group_session().await?;
AnyMessageEventContent::RoomEncrypted(
self.client
.base_client
.encrypt(self.inner.room_id(), content)
.await?,
self.client.base_client.encrypt(self.inner.room_id(), content).await?,
)
} else {
content.into()
@ -426,8 +405,9 @@ impl Joined {
/// If the room is encrypted and the encryption feature is enabled the
/// upload will be encrypted.
///
/// This is a convenience method that calls the [`Client::upload()`](#Client::method.upload)
/// and afterwards the [`send()`](#method.send).
/// This is a convenience method that calls the
/// [`Client::upload()`](#Client::method.upload) and afterwards the
/// [`send()`](#method.send).
///
/// # Arguments
/// * `body` - A textual representation of the media that is going to be
@ -534,11 +514,8 @@ impl Joined {
}),
};
self.send(
AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)),
txn_id,
)
.await
self.send(AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)), txn_id)
.await
}
/// Send a room state event to the homeserver.
@ -635,10 +612,10 @@ impl Joined {
txn_id: Option<Uuid>,
) -> Result<redact_event::Response> {
let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string();
let request = assign!(
redact_event::Request::new(self.inner.room_id(), event_id, &txn_id),
{ reason }
);
let request =
assign!(redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), {
reason
});
self.client.send(request, None).await
}

View file

@ -23,9 +23,7 @@ impl Left {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Left {
Some(Self {
inner: Common::new(client, room),
})
Some(Self { inner: Common::new(client, room) })
} else {
None
}

View file

@ -21,10 +21,7 @@ impl Deref for RoomMember {
impl RoomMember {
pub(crate) fn new(client: Client, member: BaseRoomMember) -> Self {
Self {
inner: member,
client,
}
Self { inner: member, client }
}
/// Gets the avatar of this member, if set.

View file

@ -26,11 +26,7 @@ impl AppserviceEventHandler {
#[async_trait]
impl EventHandler for AppserviceEventHandler {
async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) {
if !self
.appservice
.user_id_is_in_namespace(&event.state_key)
.unwrap()
{
if !self.appservice.user_id_is_in_namespace(&event.state_key).unwrap() {
dbg!("not an appservice user");
return;
}
@ -38,11 +34,7 @@ impl EventHandler for AppserviceEventHandler {
if let MembershipState::Invite = event.content.membership {
let user_id = UserId::try_from(event.state_key.clone()).unwrap();
let client = self
.appservice
.client_with_localpart(user_id.localpart())
.await
.unwrap();
let client = self.appservice.client_with_localpart(user_id.localpart()).await.unwrap();
client.join_room_by_id(room.room_id()).await.unwrap();
}
@ -51,10 +43,7 @@ impl EventHandler for AppserviceEventHandler {
#[actix_web::main]
pub async fn main() -> std::io::Result<()> {
env::set_var(
"RUST_LOG",
"actix_web=debug,actix_server=info,matrix_sdk=debug",
);
env::set_var("RUST_LOG", "actix_web=debug,actix_server=info,matrix_sdk=debug");
tracing_subscriber::fmt::init();
let homeserver_url = "http://localhost:8008";
@ -62,16 +51,11 @@ pub async fn main() -> std::io::Result<()> {
let registration =
AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap();
let appservice = Appservice::new(homeserver_url, server_name, registration)
.await
.unwrap();
let appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap();
let event_handler = AppserviceEventHandler::new(appservice.clone());
appservice
.client()
.set_event_handler(Box::new(event_handler))
.await;
appservice.client().set_event_handler(Box::new(event_handler)).await;
HttpServer::new(move || App::new().service(appservice.actix_service()))
.bind(("0.0.0.0", 8090))?

View file

@ -52,10 +52,7 @@ pub fn get_scope() -> Scope {
}
fn gen_scope(scope: &str) -> Scope {
web::scope(scope)
.service(push_transactions)
.service(query_user_id)
.service(query_room_alias)
web::scope(scope).service(push_transactions).service(query_user_id).service(query_room_alias)
}
#[tracing::instrument]
@ -68,11 +65,7 @@ async fn push_transactions(
return Ok(HttpResponse::Unauthorized().finish());
}
appservice
.client()
.receive_transaction(request.incoming)
.await
.unwrap();
appservice.client().receive_transaction(request.incoming).await.unwrap();
Ok(HttpResponse::Ok().json("{}"))
}
@ -135,13 +128,9 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
uri
};
let mut builder = http::request::Builder::new()
.method(request.method())
.uri(uri);
let mut builder = http::request::Builder::new().method(request.method()).uri(uri);
let headers = builder
.headers_mut()
.ok_or(Error::UnknownHttpRequestBuilder)?;
let headers = builder.headers_mut().ok_or(Error::UnknownHttpRequestBuilder)?;
for (key, value) in request.headers().iter() {
headers.append(key, value.to_owned());
}
@ -157,10 +146,7 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
let access_token = match request.uri().query() {
Some(query) => {
let query: Vec<(String, String)> = matrix_sdk::urlencoded::from_str(query)?;
query
.into_iter()
.find(|(key, _)| key == "access_token")
.map(|(_, value)| value)
query.into_iter().find(|(key, _)| key == "access_token").map(|(_, value)| value)
}
None => None,
};

View file

@ -104,9 +104,7 @@ impl AppserviceRegistration {
///
/// See the fields of [`Registration`] for the required format
pub fn try_from_yaml_str(value: impl AsRef<str>) -> Result<Self> {
Ok(Self {
inner: serde_yaml::from_str(value.as_ref())?,
})
Ok(Self { inner: serde_yaml::from_str(value.as_ref())? })
}
/// Try to load registration from yaml file
@ -115,9 +113,7 @@ impl AppserviceRegistration {
pub fn try_from_yaml_file(path: impl Into<PathBuf>) -> Result<Self> {
let file = File::open(path.into())?;
Ok(Self {
inner: serde_yaml::from_reader(file)?,
})
Ok(Self { inner: serde_yaml::from_reader(file)? })
}
}

View file

@ -6,10 +6,7 @@ mod actix {
use matrix_sdk_appservice::*;
async fn appservice() -> Appservice {
env::set_var(
"RUST_LOG",
"mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug",
);
env::set_var("RUST_LOG", "mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug");
let _ = tracing_subscriber::fmt::try_init();
Appservice::new(

View file

@ -76,10 +76,7 @@ async fn test_event_handler() -> Result<()> {
}
}
appservice
.client()
.set_event_handler(Box::new(Example::new()))
.await;
appservice.client().set_event_handler(Box::new(Example::new())).await;
let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap();
let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into();

View file

@ -76,17 +76,8 @@ impl InspectorHelper {
fn complete_event_types(&self, arg: Option<&&str>) -> Vec<Pair> {
Self::EVENT_TYPES
.iter()
.map(|t| Pair {
display: t.to_string(),
replacement: format!("{} ", t),
})
.filter(|r| {
if let Some(arg) = arg {
r.replacement.starts_with(arg)
} else {
true
}
})
.map(|t| Pair { display: t.to_string(), replacement: format!("{} ", t) })
.filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
.collect()
}
@ -99,13 +90,7 @@ impl InspectorHelper {
display: r.room_id.to_string(),
replacement: format!("{} ", r.room_id.to_string()),
})
.filter(|r| {
if let Some(arg) = arg {
r.replacement.starts_with(arg)
} else {
true
}
})
.filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
.collect()
}
}
@ -124,15 +109,9 @@ impl Completer for InspectorHelper {
let commands = vec![
("get-state", "get a state event in the given room"),
(
"get-profiles",
"get all the stored profiles in the given room",
),
("get-profiles", "get all the stored profiles in the given room"),
("list-rooms", "list all rooms"),
(
"get-members",
"get all the membership events in the given room",
),
("get-members", "get all the membership events in the given room"),
]
.iter()
.map(|(r, d)| Pair {
@ -151,19 +130,13 @@ impl Completer for InspectorHelper {
} else {
Ok((
0,
commands
.into_iter()
.filter(|c| c.replacement.starts_with(args[0]))
.collect(),
commands.into_iter().filter(|c| c.replacement.starts_with(args[0])).collect(),
))
}
} else if args.len() == 2 {
if args[0] == "get-state" {
if line.ends_with(' ') {
Ok((
args[0].len() + args[1].len() + 2,
self.complete_event_types(args.get(2)),
))
Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
} else {
Ok((args[0].len() + 1, self.complete_rooms(args.get(1))))
}
@ -174,10 +147,7 @@ impl Completer for InspectorHelper {
}
} else if args.len() == 3 {
if args[0] == "get-state" {
Ok((
args[0].len() + args[1].len() + 2,
self.complete_event_types(args.get(2)),
))
Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
} else {
Ok((pos, vec![]))
}
@ -213,12 +183,7 @@ impl Printer {
let syntax_set: SyntaxSet = from_binary(include_bytes!("./syntaxes.bin"));
let themes: ThemeSet = from_binary(include_bytes!("./themes.bin"));
Self {
ps: syntax_set.into(),
ts: themes.into(),
json,
color,
}
Self { ps: syntax_set.into(), ts: themes.into(), json, color }
}
fn pretty_print_struct<T: Debug + Serialize>(&self, data: &T) {
@ -229,13 +194,9 @@ impl Printer {
};
let syntax = if self.json {
self.ps
.find_syntax_by_extension("rs")
.expect("Can't find rust syntax extension")
self.ps.find_syntax_by_extension("rs").expect("Can't find rust syntax extension")
} else {
self.ps
.find_syntax_by_extension("json")
.expect("Can't find json syntax extension")
self.ps.find_syntax_by_extension("json").expect("Can't find json syntax extension")
};
if self.color {
@ -302,11 +263,7 @@ impl Inspector {
}
async fn get_display_name_owners(&self, room_id: RoomId, display_name: String) {
let users = self
.store
.get_users_with_display_name(&room_id, &display_name)
.await
.unwrap();
let users = self.store.get_users_with_display_name(&room_id, &display_name).await.unwrap();
self.printer.pretty_print_struct(&users);
}
@ -323,22 +280,14 @@ impl Inspector {
let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
for member in joined {
let event = self
.store
.get_member_event(&room_id, &member)
.await
.unwrap();
let event = self.store.get_member_event(&room_id, &member).await.unwrap();
self.printer.pretty_print_struct(&event);
}
}
async fn get_state(&self, room_id: RoomId, event_type: EventType) {
self.printer.pretty_print_struct(
&self
.store
.get_state_event(&room_id, event_type, "")
.await
.unwrap(),
&self.store.get_state_event(&room_id, event_type, "").await.unwrap(),
);
}
@ -347,35 +296,25 @@ impl Inspector {
SubCommand::with_name("list-rooms"),
SubCommand::with_name("get-members").arg(
Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r)
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
}),
),
SubCommand::with_name("get-profiles").arg(
Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r)
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
}),
),
SubCommand::with_name("get-display-names")
.arg(Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r)
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
}))
.arg(Arg::with_name("display-name").required(true)),
SubCommand::with_name("get-state")
.arg(Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r)
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
}))
.arg(Arg::with_name("event-type").required(true).validator(|e| {
EventType::try_from(e)
.map(|_| ())
.map_err(|_| "Invalid event type".to_string())
EventType::try_from(e).map(|_| ()).map_err(|_| "Invalid event type".to_string())
})),
]
}

View file

@ -87,20 +87,21 @@ pub struct AdditionalUnsignedData {
pub prev_content: Option<Raw<MemberEventContent>>,
}
/// Transform state event by hoisting `prev_content` field from `unsigned` to the top level.
/// Transform state event by hoisting `prev_content` field from `unsigned` to
/// the top level.
///
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in `unsigned` contrary to
/// the C2S spec. Some more discussion can be found [here][discussion]. Until this is fixed in
/// synapse or handled in Ruma, we use this to hoist up `prev_content` to the top level.
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
/// `unsigned` contrary to the C2S spec. Some more discussion can be found
/// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we
/// use this to hoist up `prev_content` to the top level.
///
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
pub fn hoist_and_deserialize_state_event(
event: &Raw<AnySyncStateEvent>,
) -> StdResult<AnySyncStateEvent, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())?
.unsigned
.prev_content;
let prev_content =
serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
let mut ev = event.deserialize()?;
@ -116,9 +117,8 @@ pub fn hoist_and_deserialize_state_event(
fn hoist_member_event(
event: &Raw<StateEvent<MemberEventContent>>,
) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())?
.unsigned
.prev_content;
let prev_content =
serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
let mut e = event.deserialize()?;
@ -340,7 +340,8 @@ impl BaseClient {
///
/// # Arguments
///
/// * `response` - A successful login response that contains our access token
/// * `response` - A successful login response that contains our access
/// token
/// and device id.
pub async fn receive_login_response(
&self,
@ -440,9 +441,7 @@ impl BaseClient {
AnySyncRoomEvent::State(s) => match s {
AnySyncStateEvent::RoomMember(member) => {
if let Ok(member) = MemberEvent::try_from(member.clone()) {
ambiguity_cache
.handle_event(changes, room_id, &member)
.await?;
ambiguity_cache.handle_event(changes, room_id, &member).await?;
match member.content.membership {
MembershipState::Join | MembershipState::Invite => {
@ -500,8 +499,7 @@ impl BaseClient {
}
if let Some(context) = &mut push_context {
self.update_push_room_context(context, user_id, room_info, changes)
.await;
self.update_push_room_context(context, user_id, room_info, changes).await;
} else {
push_context = self.get_push_room_context(room, room_info, changes).await?;
}
@ -521,10 +519,13 @@ impl BaseClient {
),
);
}
// TODO if there is an Action::SetTweak(Tweak::Highlight) we need to store
// its value with the event so a client can show if the event is highlighted
// TODO if there is an
// Action::SetTweak(Tweak::Highlight) we need to store
// its value with the event so a client can show if the
// event is highlighted
// in the UI.
// Requires the possibility to associate custom data with events and to
// Requires the possibility to associate custom data
// with events and to
// store them.
}
}
@ -762,18 +763,14 @@ impl BaseClient {
let mut changes = StateChanges::new(next_batch.clone());
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
self.handle_account_data(&account_data.events, &mut changes)
.await;
self.handle_account_data(&account_data.events, &mut changes).await;
let push_rules = self.get_push_rules(&changes).await?;
let mut new_rooms = Rooms::default();
for (room_id, new_info) in rooms.join {
let room = self
.store
.get_or_create_room(&room_id, RoomType::Joined)
.await;
let room = self.store.get_or_create_room(&room_id, RoomType::Joined).await;
let mut room_info = room.clone_info();
room_info.mark_as_joined();
@ -844,10 +841,7 @@ impl BaseClient {
}
for (room_id, new_info) in rooms.leave {
let room = self
.store
.get_or_create_room(&room_id, RoomType::Left)
.await;
let room = self.store.get_or_create_room(&room_id, RoomType::Left).await;
let mut room_info = room.clone_info();
room_info.mark_as_left();
@ -876,18 +870,14 @@ impl BaseClient {
.await;
changes.add_room(room_info);
new_rooms.leave.insert(
room_id,
LeftRoom::new(timeline, new_info.state, new_info.account_data),
);
new_rooms
.leave
.insert(room_id, LeftRoom::new(timeline, new_info.state, new_info.account_data));
}
for (room_id, new_info) in rooms.invite {
{
let room = self
.store
.get_or_create_room(&room_id, RoomType::Invited)
.await;
let room = self.store.get_or_create_room(&room_id, RoomType::Invited).await;
let mut room_info = room.clone_info();
room_info.mark_as_invited();
changes.add_room(room_info);
@ -934,9 +924,7 @@ impl BaseClient {
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect(),
ambiguity_changes: AmbiguityChanges {
changes: ambiguity_cache.changes,
},
ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
notifications: changes.notifications,
};
@ -968,11 +956,7 @@ impl BaseClient {
let members: Vec<MemberEvent> = response
.chunk
.iter()
.filter_map(|e| {
hoist_member_event(e)
.ok()
.and_then(|e| MemberEvent::try_from(e).ok())
})
.filter_map(|e| hoist_member_event(e).ok().and_then(|e| MemberEvent::try_from(e).ok()))
.collect();
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
@ -986,12 +970,7 @@ impl BaseClient {
let mut user_ids = BTreeSet::new();
for member in &members {
if self
.store
.get_member_event(&room_id, &member.state_key)
.await?
.is_none()
{
if self.store.get_member_event(&room_id, &member.state_key).await?.is_none() {
#[cfg(feature = "encryption")]
match member.content.membership {
MembershipState::Join | MembershipState::Invite => {
@ -1000,9 +979,7 @@ impl BaseClient {
_ => (),
}
ambiguity_cache
.handle_event(&changes, room_id, &member)
.await?;
ambiguity_cache.handle_event(&changes, room_id, &member).await?;
if member.state_key == member.sender {
changes
@ -1036,9 +1013,7 @@ impl BaseClient {
Ok(MembersResponse {
chunk: members,
ambiguity_changes: AmbiguityChanges {
changes: ambiguity_cache.changes,
},
ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
})
}
@ -1050,7 +1025,8 @@ impl BaseClient {
///
/// # Arguments
///
/// * `filter_name` - The name that should be used to persist the filter id in
/// * `filter_name` - The name that should be used to persist the filter id
/// in
/// the store.
///
/// * `response` - The successful filter upload response containing the
@ -1062,10 +1038,7 @@ impl BaseClient {
filter_name: &str,
response: &api::filter::create_filter::Response,
) -> Result<()> {
Ok(self
.store
.save_filter(filter_name, &response.filter_id)
.await?)
Ok(self.store.save_filter(filter_name, &response.filter_id).await?)
}
/// Get the filter id of a previously uploaded filter.
@ -1223,18 +1196,15 @@ impl BaseClient {
///
/// # Arguments
///
/// * `flow_id` - The unique id that identifies a interactive verification flow. For in-room
/// verifications this will be the event id of the *m.key.verification.request* event that
/// started the flow, for the to-device verification flows this will be the transaction id of
/// the *m.key.verification.start* event.
/// * `flow_id` - The unique id that identifies a interactive verification
/// flow. For in-room verifications this will be the event id of the
/// *m.key.verification.request* event that started the flow, for the
/// to-device verification flows this will be the transaction id of the
/// *m.key.verification.start* event.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
self.olm
.lock()
.await
.as_ref()
.and_then(|o| o.get_verification(flow_id))
self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
}
/// Get a specific device of a user.
@ -1283,10 +1253,12 @@ impl BaseClient {
/// Get the user login session.
///
/// If the client is currently logged in, this will return a `matrix_sdk::Session` object which
/// can later be given to `restore_login`.
/// If the client is currently logged in, this will return a
/// `matrix_sdk::Session` object which can later be given to
/// `restore_login`.
///
/// Returns a session object if the client is logged in. Otherwise returns `None`.
/// Returns a session object if the client is logged in. Otherwise returns
/// `None`.
pub async fn get_session(&self) -> Option<Session> {
self.session.read().await.clone()
}
@ -1348,8 +1320,9 @@ impl BaseClient {
/// Get the push rules.
///
/// Gets the push rules from `changes` if they have been updated, otherwise get them from the
/// store. As a fallback, uses `Ruleset::server_default` if the user is logged in.
/// Gets the push rules from `changes` if they have been updated, otherwise
/// get them from the store. As a fallback, uses
/// `Ruleset::server_default` if the user is logged in.
pub async fn get_push_rules(&self, changes: &StateChanges) -> Result<Ruleset> {
if let Some(AnyGlobalAccountDataEvent::PushRules(event)) = changes
.account_data
@ -1373,11 +1346,11 @@ impl BaseClient {
/// Get the push context for the given room.
///
/// Tries to get the data from `changes` or the up to date `room_info`. Loads the data from the
/// store otherwise.
/// Tries to get the data from `changes` or the up to date `room_info`.
/// Loads the data from the store otherwise.
///
/// Returns `None` if some data couldn't be found. This should only happen in brand new rooms,
/// while we process its state.
/// Returns `None` if some data couldn't be found. This should only happen
/// in brand new rooms, while we process its state.
pub async fn get_push_room_context(
&self,
room: &Room,
@ -1389,16 +1362,10 @@ impl BaseClient {
let member_count = room_info.active_members_count();
let user_display_name = if let Some(member) = changes
.members
.get(room_id)
.and_then(|members| members.get(user_id))
let user_display_name = if let Some(member) =
changes.members.get(room_id).and_then(|members| members.get(user_id))
{
member
.content
.displayname
.clone()
.unwrap_or_else(|| user_id.localpart().to_owned())
member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
} else if let Some(member) = room.get_member(user_id).await? {
member.name().to_owned()
} else {
@ -1448,16 +1415,10 @@ impl BaseClient {
push_rules.member_count = UInt::new(room_info.active_members_count()).unwrap_or(UInt::MAX);
if let Some(member) = changes
.members
.get(room_id)
.and_then(|members| members.get(user_id))
if let Some(member) = changes.members.get(room_id).and_then(|members| members.get(user_id))
{
push_rules.user_display_name = member
.content
.displayname
.clone()
.unwrap_or_else(|| user_id.localpart().to_owned())
push_rules.user_display_name =
member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
}
if let Some(AnySyncStateEvent::RoomPowerLevels(event)) = changes

View file

@ -28,7 +28,8 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Internal representation of errors.
#[derive(Error, Debug)]
pub enum Error {
/// Queried endpoint requires authentication but was called on an anonymous client.
/// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired,

View file

@ -66,20 +66,12 @@ impl BaseRoomInfo {
let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1);
if heroes_count >= invited_joined {
let mut names = heroes
.iter()
.take(3)
.map(|mem| mem.name())
.collect::<Vec<&str>>();
let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
// stabilize ordering
names.sort_unstable();
names.join(", ")
} else if heroes_count < invited_joined && invited_joined > 1 {
let mut names = heroes
.iter()
.take(3)
.map(|mem| mem.name())
.collect::<Vec<&str>>();
let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
names.sort_unstable();
// TODO: What length does the spec want us to use here and in
@ -144,10 +136,8 @@ impl BaseRoomInfo {
true
}
AnyStateEventContent::RoomPowerLevels(p) => {
let max_power_level = p
.users
.values()
.fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
let max_power_level =
p.users.values().fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
self.max_power_level = max_power_level;
true
}

View file

@ -43,7 +43,8 @@ use crate::{
store::{Result as StoreResult, StateStore},
};
/// The underlying room data structure collecting state for joined, left and invtied rooms.
/// The underlying room data structure collecting state for joined, left and
/// invtied rooms.
#[derive(Debug, Clone)]
pub struct Room {
room_id: Arc<RoomId>,
@ -134,7 +135,8 @@ impl Room {
/// Check if the room has it's members fully synced.
///
/// Members might be missing if lazy member loading was enabled for the sync.
/// Members might be missing if lazy member loading was enabled for the
/// sync.
///
/// Returns true if no members are missing, false otherwise.
pub fn are_members_synced(&self) -> bool {
@ -199,12 +201,7 @@ impl Room {
/// Get the history visibility policy of this room.
pub fn history_visibility(&self) -> HistoryVisibility {
self.inner
.read()
.unwrap()
.base_info
.history_visibility
.clone()
self.inner.read().unwrap().base_info.history_visibility.clone()
}
/// Is the room considered to be public.
@ -366,9 +363,7 @@ impl Room {
);
let inner = self.inner.read().unwrap();
Ok(inner
.base_info
.calculate_room_name(joined, invited, members))
Ok(inner.base_info.calculate_room_name(joined, invited, members))
}
pub(crate) fn clone_info(&self) -> RoomInfo {
@ -393,11 +388,8 @@ impl Room {
return Ok(None);
};
let presence = self
.store
.get_presence_event(user_id)
.await?
.and_then(|e| e.deserialize().ok());
let presence =
self.store.get_presence_event(user_id).await?.and_then(|e| e.deserialize().ok());
let profile = self.store.get_profile(self.room_id(), user_id).await?;
let max_power_level = self.max_power_level();
let is_room_creator = self
@ -410,28 +402,24 @@ impl Room {
.map(|c| &c.creator == user_id)
.unwrap_or(false);
let power = self
.store
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
.await?
.and_then(|e| e.deserialize().ok())
.and_then(|e| {
if let AnySyncStateEvent::RoomPowerLevels(e) = e {
Some(e)
} else {
None
}
});
let power =
self.store
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
.await?
.and_then(|e| e.deserialize().ok())
.and_then(|e| {
if let AnySyncStateEvent::RoomPowerLevels(e) = e {
Some(e)
} else {
None
}
});
let ambiguous = self
.store
.get_users_with_display_name(
self.room_id(),
member_event
.content
.displayname
.as_deref()
.unwrap_or_else(|| user_id.localpart()),
member_event.content.displayname.as_deref().unwrap_or_else(|| user_id.localpart()),
)
.await?
.len()
@ -557,8 +545,6 @@ impl RoomInfo {
///
/// The return value is saturated at `u64::MAX`.
pub fn active_members_count(&self) -> u64 {
self.summary
.joined_member_count
.saturating_add(self.summary.invited_member_count)
self.summary.joined_member_count.saturating_add(self.summary.invited_member_count)
}
}

View file

@ -49,11 +49,8 @@ impl AmbiguityMap {
}
fn add(&mut self, user_id: UserId) -> Option<UserId> {
let ambiguous_user = if self.user_count() == 1 {
self.users.iter().next().cloned()
} else {
None
};
let ambiguous_user =
if self.user_count() == 1 { self.users.iter().next().cloned() } else { None };
self.users.insert(user_id);
@ -71,11 +68,7 @@ impl AmbiguityMap {
impl AmbiguityCache {
pub fn new(store: Store) -> Self {
Self {
store,
cache: BTreeMap::new(),
changes: BTreeMap::new(),
}
Self { store, cache: BTreeMap::new(), changes: BTreeMap::new() }
}
pub async fn handle_event(
@ -113,12 +106,9 @@ impl AmbiguityCache {
return Ok(());
}
let disambiguated_member = old_map
.as_mut()
.and_then(|o| o.remove(&member_event.state_key));
let ambiguated_member = new_map
.as_mut()
.and_then(|n| n.add(member_event.state_key.clone()));
let disambiguated_member = old_map.as_mut().and_then(|o| o.remove(&member_event.state_key));
let ambiguated_member =
new_map.as_mut().and_then(|n| n.add(member_event.state_key.clone()));
let ambiguous = new_map.as_ref().map(|n| n.is_ambiguous()).unwrap_or(false);
self.update(room_id, old_map, new_map);
@ -129,11 +119,7 @@ impl AmbiguityCache {
member_ambiguous: ambiguous,
};
trace!(
"Handling display name ambiguity for {}: {:#?}",
member_event.state_key,
change
);
trace!("Handling display name ambiguity for {}: {:#?}", member_event.state_key, change);
self.add_change(room_id, member_event.event_id.clone(), change);
@ -146,10 +132,7 @@ impl AmbiguityCache {
old_map: Option<AmbiguityMap>,
new_map: Option<AmbiguityMap>,
) {
let entry = self
.cache
.entry(room_id.clone())
.or_insert_with(BTreeMap::new);
let entry = self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new);
if let Some(old) = old_map {
entry.insert(old.display_name, old.users);
@ -161,10 +144,7 @@ impl AmbiguityCache {
}
fn add_change(&mut self, room_id: &RoomId, event_id: EventId, change: AmbiguityChange) {
self.changes
.entry(room_id.clone())
.or_insert_with(BTreeMap::new)
.insert(event_id, change);
self.changes.entry(room_id.clone()).or_insert_with(BTreeMap::new).insert(event_id, change);
}
async fn get(
@ -175,16 +155,12 @@ impl AmbiguityCache {
) -> Result<(Option<AmbiguityMap>, Option<AmbiguityMap>)> {
use MembershipState::*;
let old_event = if let Some(m) = changes
.members
.get(room_id)
.and_then(|m| m.get(&member_event.state_key))
let old_event = if let Some(m) =
changes.members.get(room_id).and_then(|m| m.get(&member_event.state_key))
{
Some(m.clone())
} else {
self.store
.get_member_event(room_id, &member_event.state_key)
.await?
self.store.get_member_event(room_id, &member_event.state_key).await?
};
let old_display_name = if let Some(event) = old_event {
@ -216,23 +192,15 @@ impl AmbiguityCache {
};
let old_map = if let Some(old_name) = old_display_name.as_deref() {
let old_display_name_map = if let Some(u) = self
.cache
.entry(room_id.clone())
.or_insert_with(BTreeMap::new)
.get(old_name)
let old_display_name_map = if let Some(u) =
self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new).get(old_name)
{
u.clone()
} else {
self.store
.get_users_with_display_name(&room_id, &old_name)
.await?
self.store.get_users_with_display_name(&room_id, &old_name).await?
};
Some(AmbiguityMap {
display_name: old_name.to_string(),
users: old_display_name_map,
})
Some(AmbiguityMap { display_name: old_name.to_string(), users: old_display_name_map })
} else {
None
};
@ -244,8 +212,9 @@ impl AmbiguityCache {
.as_deref()
.unwrap_or_else(|| member_event.state_key.localpart());
// We don't allow other users to set the display name, so if we have
// a more trusted version of the display name use that.
// We don't allow other users to set the display name, so if we
// have a more trusted version of the display
// name use that.
let new_display_name = if member_event.sender.as_str() == member_event.state_key {
new
} else if let Some(old) = old_display_name.as_deref() {
@ -262,9 +231,7 @@ impl AmbiguityCache {
{
u.clone()
} else {
self.store
.get_users_with_display_name(&room_id, &new_display_name)
.await?
self.store.get_users_with_display_name(&room_id, &new_display_name).await?
};
Some(AmbiguityMap {

View file

@ -80,8 +80,7 @@ impl MemoryStore {
}
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.filters
.insert(filter_name.to_string(), filter_id.to_string());
self.filters.insert(filter_name.to_string(), filter_id.to_string());
Ok(())
}
@ -162,8 +161,7 @@ impl MemoryStore {
}
for (event_type, event) in &changes.account_data {
self.account_data
.insert(event_type.to_string(), event.clone());
self.account_data.insert(event_type.to_string(), event.clone());
}
for (room, events) in &changes.room_account_data {
@ -197,8 +195,7 @@ impl MemoryStore {
}
for (room_id, info) in &changes.invited_room_info {
self.stripped_room_info
.insert(room_id.clone(), info.clone());
self.stripped_room_info.insert(room_id.clone(), info.clone());
}
for (room, events) in &changes.stripped_members {
@ -241,8 +238,7 @@ impl MemoryStore {
) -> Result<Option<Raw<AnySyncStateEvent>>> {
#[allow(clippy::map_clone)]
Ok(self.room_state.get(room_id).and_then(|e| {
e.get(event_type.as_ref())
.and_then(|s| s.get(state_key).map(|e| e.clone()))
e.get(event_type.as_ref()).and_then(|s| s.get(state_key).map(|e| e.clone()))
}))
}
@ -252,10 +248,7 @@ impl MemoryStore {
user_id: &UserId,
) -> Result<Option<MemberEventContent>> {
#[allow(clippy::map_clone)]
Ok(self
.profiles
.get(room_id)
.and_then(|p| p.get(user_id).map(|p| p.clone())))
Ok(self.profiles.get(room_id).and_then(|p| p.get(user_id).map(|p| p.clone())))
}
async fn get_member_event(
@ -264,10 +257,7 @@ impl MemoryStore {
state_key: &UserId,
) -> Result<Option<MemberEvent>> {
#[allow(clippy::map_clone)]
Ok(self
.members
.get(room_id)
.and_then(|m| m.get(state_key).map(|m| m.clone())))
Ok(self.members.get(room_id).and_then(|m| m.get(state_key).map(|m| m.clone())))
}
fn get_user_ids(&self, room_id: &RoomId) -> Vec<UserId> {
@ -308,10 +298,7 @@ impl MemoryStore {
&self,
event_type: EventType,
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
Ok(self
.account_data
.get(event_type.as_ref())
.map(|e| e.clone()))
Ok(self.account_data.get(event_type.as_ref()).map(|e| e.clone()))
}
async fn get_room_account_data_event(

View file

@ -200,7 +200,8 @@ pub trait StateStore: AsyncTraitDeps {
///
/// # Arguments
///
/// * `room_id` - The id of the room for which the room account data event should
/// * `room_id` - The id of the room for which the room account data event
/// should
/// be fetched.
///
/// * `event_type` - The event type of the room account data event.
@ -297,20 +298,16 @@ impl Store {
/// Get all the rooms this store knows about.
pub fn get_rooms(&self) -> Vec<Room> {
self.rooms
.iter()
.filter_map(|r| self.get_room(r.key()))
.collect()
self.rooms.iter().filter_map(|r| self.get_room(r.key())).collect()
}
/// Get the room with the given room id.
pub fn get_room(&self, room_id: &RoomId) -> Option<Room> {
self.get_bare_room(room_id)
.and_then(|r| match r.room_type() {
RoomType::Joined => Some(r),
RoomType::Left => Some(r),
RoomType::Invited => self.get_stripped_room(room_id),
})
self.get_bare_room(room_id).and_then(|r| match r.room_type() {
RoomType::Joined => Some(r),
RoomType::Left => Some(r),
RoomType::Invited => self.get_stripped_room(room_id),
})
}
fn get_stripped_room(&self, room_id: &RoomId) -> Option<Room> {
@ -320,10 +317,7 @@ impl Store {
pub(crate) async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room {
let session = self.session.read().await;
let user_id = &session
.as_ref()
.expect("Creating room while not being logged in")
.user_id;
let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
self.stripped_rooms
.entry(room_id.clone())
@ -333,10 +327,7 @@ impl Store {
pub(crate) async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room {
let session = self.session.read().await;
let user_id = &session
.as_ref()
.expect("Creating room while not being logged in")
.user_id;
let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
self.rooms
.entry(room_id.clone())
@ -358,8 +349,8 @@ impl Deref for Store {
pub struct StateChanges {
/// The sync token that relates to this update.
pub sync_token: Option<String>,
/// A user session, containing an access token and information about the associated user
/// account.
/// A user session, containing an access token and information about the
/// associated user account.
pub session: Option<Session>,
/// A mapping of event type string to `AnyBasicEvent`.
pub account_data: BTreeMap<String, Raw<AnyGlobalAccountDataEvent>>,
@ -371,7 +362,8 @@ pub struct StateChanges {
/// A mapping of `RoomId` to a map of users and their `MemberEventContent`.
pub profiles: BTreeMap<RoomId, BTreeMap<UserId, MemberEventContent>>,
/// A mapping of `RoomId` to a map of event type string to a state key and `AnySyncStateEvent`.
/// A mapping of `RoomId` to a map of event type string to a state key and
/// `AnySyncStateEvent`.
pub state: BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnySyncStateEvent>>>>,
/// A mapping of `RoomId` to a map of event type string to `AnyBasicEvent`.
pub room_account_data: BTreeMap<RoomId, BTreeMap<String, Raw<AnyRoomAccountDataEvent>>>,
@ -397,10 +389,7 @@ pub struct StateChanges {
impl StateChanges {
/// Create a new `StateChanges` struct with the given sync_token.
pub fn new(sync_token: String) -> Self {
Self {
sync_token: Some(sync_token),
..Default::default()
}
Self { sync_token: Some(sync_token), ..Default::default() }
}
/// Update the `StateChanges` struct with the given `PresenceEvent`.
@ -410,14 +399,12 @@ impl StateChanges {
/// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_room(&mut self, room: RoomInfo) {
self.room_infos
.insert(room.room_id.as_ref().to_owned(), room);
self.room_infos.insert(room.room_id.as_ref().to_owned(), room);
}
/// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_stripped_room(&mut self, room: RoomInfo) {
self.invited_room_info
.insert(room.room_id.as_ref().to_owned(), room);
self.invited_room_info.insert(room.room_id.as_ref().to_owned(), room);
}
/// Update the `StateChanges` struct with the given `AnyBasicEvent`.
@ -426,11 +413,11 @@ impl StateChanges {
event: AnyGlobalAccountDataEvent,
raw_event: Raw<AnyGlobalAccountDataEvent>,
) {
self.account_data
.insert(event.content().event_type().to_owned(), raw_event);
self.account_data.insert(event.content().event_type().to_owned(), raw_event);
}
/// Update the `StateChanges` struct with the given room with a new `AnyBasicEvent`.
/// Update the `StateChanges` struct with the given room with a new
/// `AnyBasicEvent`.
pub fn add_room_account_data(
&mut self,
room_id: &RoomId,
@ -443,7 +430,8 @@ impl StateChanges {
.insert(event.content().event_type().to_owned(), raw_event);
}
/// Update the `StateChanges` struct with the given room with a new `StrippedMemberEvent`.
/// Update the `StateChanges` struct with the given room with a new
/// `StrippedMemberEvent`.
pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) {
let user_id = event.state_key.clone();
@ -453,7 +441,8 @@ impl StateChanges {
.insert(user_id, event);
}
/// Update the `StateChanges` struct with the given room with a new `AnySyncStateEvent`.
/// Update the `StateChanges` struct with the given room with a new
/// `AnySyncStateEvent`.
pub fn add_state_event(
&mut self,
room_id: &RoomId,
@ -468,11 +457,9 @@ impl StateChanges {
.insert(event.state_key().to_string(), raw_event);
}
/// Update the `StateChanges` struct with the given room with a new `Notification`.
/// Update the `StateChanges` struct with the given room with a new
/// `Notification`.
pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) {
self.notifications
.entry(room_id.to_owned())
.or_insert_with(Vec::new)
.push(notification);
self.notifications.entry(room_id.to_owned()).or_insert_with(Vec::new).push(notification);
}
}

View file

@ -108,13 +108,7 @@ impl EncodeKey for &str {
impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> {
[
self.0.as_bytes(),
&[Self::SEPARATOR],
self.1.as_bytes(),
&[Self::SEPARATOR],
]
.concat()
[self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
}
}
@ -164,9 +158,7 @@ impl std::fmt::Debug for SledStore {
if let Some(path) = &self.path {
f.debug_struct("SledStore").field("path", &path).finish()
} else {
f.debug_struct("SledStore")
.field("path", &"memory store")
.finish()
f.debug_struct("SledStore").field("path", &"memory store").finish()
}
}
}
@ -236,8 +228,7 @@ impl SledStore {
} else {
let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?;
let encrypted_key = DatabaseType::Encrypted(
key.export(passphrase)
.map_err::<StoreError, _>(|e| e.into())?,
key.export(passphrase).map_err::<StoreError, _>(|e| e.into())?,
);
db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?;
key
@ -275,8 +266,7 @@ impl SledStore {
}
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.session
.insert(("filter", filter_name).encode(), filter_id)?;
self.session.insert(("filter", filter_name).encode(), filter_id)?;
Ok(())
}
@ -476,11 +466,7 @@ impl SledStore {
}
pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
Ok(self
.presence
.get(user_id.encode())?
.map(|e| self.deserialize_event(&e))
.transpose()?)
Ok(self.presence.get(user_id.encode())?.map(|e| self.deserialize_event(&e)).transpose()?)
}
pub async fn get_state_event(
@ -531,14 +517,10 @@ impl SledStore {
&self,
room_id: &RoomId,
) -> impl Stream<Item = Result<UserId>> {
stream::iter(
self.invited_user_ids
.scan_prefix(room_id.encode())
.map(|u| {
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.map_err(StoreError::Identifier)
}),
)
stream::iter(self.invited_user_ids.scan_prefix(room_id.encode()).map(|u| {
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.map_err(StoreError::Identifier)
}))
}
pub async fn get_joined_user_ids(
@ -554,9 +536,7 @@ impl SledStore {
pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> {
let db = self.clone();
stream::iter(
self.room_info
.iter()
.map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
self.room_info.iter().map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
)
}
@ -680,8 +660,7 @@ impl StateStore for SledStore {
room_id: &RoomId,
display_name: &str,
) -> Result<BTreeSet<UserId>> {
self.get_users_with_display_name(room_id, display_name)
.await
self.get_users_with_display_name(room_id, display_name).await
}
async fn get_account_data_event(
@ -767,11 +746,7 @@ mod test {
let room_id = room_id!("!test:localhost");
let user_id = user_id();
assert!(store
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_none());
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_none());
let mut changes = StateChanges::default();
changes
.members
@ -780,11 +755,7 @@ mod test {
.insert(user_id.clone(), membership_event());
store.save_changes(&changes).await.unwrap();
assert!(store
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_some());
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some());
}
#[async_test]

View file

@ -169,10 +169,7 @@ impl StoreKey {
cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?;
Ok(EncryptedStoreKey {
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 {
rounds: KDF_ROUNDS,
kdf_salt: salt,
},
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: KDF_ROUNDS, kdf_salt: salt },
ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext },
})
}
@ -195,11 +192,7 @@ impl StoreKey {
let ciphertext = cipher.encrypt(xnonce, event.as_ref())?;
Ok(EncryptedEvent {
version: VERSION,
ciphertext,
nonce,
})
Ok(EncryptedEvent { version: VERSION, ciphertext, nonce })
}
pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> {

View file

@ -104,16 +104,14 @@ pub struct SyncRoomEvent {
impl From<Raw<AnySyncRoomEvent>> for SyncRoomEvent {
fn from(inner: Raw<AnySyncRoomEvent>) -> Self {
Self {
encryption_info: None,
event: inner,
}
Self { encryption_info: None, event: inner }
}
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SyncResponse {
/// The batch token to supply in the `since` param of the next `/sync` request.
/// The batch token to supply in the `since` param of the next `/sync`
/// request.
pub next_batch: String,
/// Updates to rooms.
pub rooms: Rooms,
@ -138,10 +136,7 @@ pub struct SyncResponse {
impl SyncResponse {
pub fn new(next_batch: String) -> Self {
Self {
next_batch,
..Default::default()
}
Self { next_batch, ..Default::default() }
}
}
@ -162,14 +157,15 @@ pub struct JoinedRoom {
pub unread_notifications: UnreadNotificationsCount,
/// The timeline of messages and state changes in the room.
pub timeline: Timeline,
/// Updates to the state, between the time indicated by the `since` parameter, and the start
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not
/// given, or `full_state` is true).
/// Updates to the state, between the time indicated by the `since`
/// parameter, and the start of the `timeline` (or all state up to the
/// start of the `timeline`, if `since` is not given, or `full_state` is
/// true).
pub state: State,
/// The private data that this user has attached to this room.
pub account_data: RoomAccountData,
/// The ephemeral events in the room that aren't recorded in the timeline or state of the
/// room. e.g. typing.
/// The ephemeral events in the room that aren't recorded in the timeline or
/// state of the room. e.g. typing.
pub ephemeral: Ephemeral,
}
@ -181,20 +177,15 @@ impl JoinedRoom {
ephemeral: Ephemeral,
unread_notifications: UnreadNotificationsCount,
) -> Self {
Self {
unread_notifications,
timeline,
state,
account_data,
ephemeral,
}
Self { unread_notifications, timeline, state, account_data, ephemeral }
}
}
/// Counts of unread notifications for a room.
#[derive(Copy, Clone, Debug, Default, Deserialize, Serialize)]
pub struct UnreadNotificationsCount {
/// The number of unread notifications for this room with the highlight flag set.
/// The number of unread notifications for this room with the highlight flag
/// set.
pub highlight_count: u64,
/// The total number of unread notifications for this room.
pub notification_count: u64,
@ -204,10 +195,7 @@ impl From<RumaUnreadNotificationsCount> for UnreadNotificationsCount {
fn from(notifications: RumaUnreadNotificationsCount) -> Self {
Self {
highlight_count: notifications.highlight_count.map(|c| c.into()).unwrap_or(0),
notification_count: notifications
.notification_count
.map(|c| c.into())
.unwrap_or(0),
notification_count: notifications.notification_count.map(|c| c.into()).unwrap_or(0),
}
}
}
@ -217,9 +205,10 @@ pub struct LeftRoom {
/// The timeline of messages and state changes in the room up to the point
/// when the user left.
pub timeline: Timeline,
/// Updates to the state, between the time indicated by the `since` parameter, and the start
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not
/// given, or `full_state` is true).
/// Updates to the state, between the time indicated by the `since`
/// parameter, and the start of the `timeline` (or all state up to the
/// start of the `timeline`, if `since` is not given, or `full_state` is
/// true).
pub state: State,
/// The private data that this user has attached to this room.
pub account_data: RoomAccountData,
@ -227,18 +216,15 @@ pub struct LeftRoom {
impl LeftRoom {
pub fn new(timeline: Timeline, state: State, account_data: RoomAccountData) -> Self {
Self {
timeline,
state,
account_data,
}
Self { timeline, state, account_data }
}
}
/// Events in the room.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Timeline {
/// True if the number of events returned was limited by the `limit` on the filter.
/// True if the number of events returned was limited by the `limit` on the
/// filter.
pub limited: bool,
/// A token that can be supplied to to the `from` parameter of the
@ -251,11 +237,7 @@ pub struct Timeline {
impl Timeline {
pub fn new(limited: bool, prev_batch: Option<String>) -> Self {
Self {
limited,
prev_batch,
..Default::default()
}
Self { limited, prev_batch, ..Default::default() }
}
}

View file

@ -49,17 +49,12 @@ fn huge_keys_query_resopnse() -> get_keys::Response {
}
pub fn keys_query(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = keys_query_response();
let uuid = Uuid::new_v4();
let count = response
.device_keys
.values()
.fold(0, |acc, d| acc + d.len())
let count = response.device_keys.values().fold(0, |acc, d| acc + d.len())
+ response.master_keys.len()
+ response.self_signing_keys.len()
+ response.user_signing_keys.len();
@ -69,14 +64,10 @@ pub fn keys_query(c: &mut Criterion) {
let name = format!("{} device and cross signing keys", count);
group.bench_with_input(
BenchmarkId::new("memory store", &name),
&response,
|b, response| {
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
});
let dir = tempfile::tempdir().unwrap();
let machine = runtime
@ -88,99 +79,74 @@ pub fn keys_query(c: &mut Criterion) {
))
.unwrap();
group.bench_with_input(
BenchmarkId::new("sled store", &name),
&response,
|b, response| {
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
});
group.finish()
}
pub fn keys_claiming(c: &mut Criterion) {
let runtime = Arc::new(
Builder::new_multi_thread()
.build()
.expect("Can't create runtime"),
);
let runtime = Arc::new(Builder::new_multi_thread().build().expect("Can't create runtime"));
let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4();
let response = keys_claim_response();
let count = response
.one_time_keys
.values()
.fold(0, |acc, d| acc + d.len());
let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Olm session creation");
group.throughput(Throughput::Elements(count as u64));
let name = format!("{} one-time keys", count);
group.bench_with_input(
BenchmarkId::new("memory store", &name),
&response,
|b, response| {
b.iter_batched(
|| {
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
},
BatchSize::SmallInput,
)
},
);
group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
b.iter_batched(
|| {
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
},
BatchSize::SmallInput,
)
});
group.bench_with_input(
BenchmarkId::new("sled store", &name),
&response,
|b, response| {
b.iter_batched(
|| {
let dir = tempfile::tempdir().unwrap();
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
},
BatchSize::SmallInput,
)
},
);
group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
b.iter_batched(
|| {
let dir = tempfile::tempdir().unwrap();
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
},
BatchSize::SmallInput,
)
});
group.finish()
}
pub fn room_key_sharing(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4();
@ -190,18 +156,11 @@ pub fn room_key_sharing(c: &mut Criterion) {
let to_device_response = ToDeviceResponse::new();
let users: Vec<UserId> = keys_query_response.device_keys.keys().cloned().collect();
let count = response
.one_time_keys
.values()
.fold(0, |acc, d| acc + d.len());
let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
let mut group = c.benchmark_group("Room key sharing");
group.throughput(Throughput::Elements(count as u64));
@ -217,10 +176,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
assert!(!requests.is_empty());
for request in requests {
machine
.mark_request_as_sent(&request.txn_id, &to_device_response)
.await
.unwrap();
machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
}
machine.invalidate_group_session(&room_id).await.unwrap();
@ -236,12 +192,8 @@ pub fn room_key_sharing(c: &mut Criterion) {
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime).iter(|| async {
@ -253,10 +205,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
assert!(!requests.is_empty());
for request in requests {
machine
.mark_request_as_sent(&request.txn_id, &to_device_response)
.await
.unwrap();
machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
}
machine.invalidate_group_session(&room_id).await.unwrap();
@ -267,28 +216,21 @@ pub fn room_key_sharing(c: &mut Criterion) {
}
pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = huge_keys_query_resopnse();
let uuid = Uuid::new_v4();
let users: Vec<UserId> = response.device_keys.keys().cloned().collect();
let count = response
.device_keys
.values()
.fold(0, |acc, d| acc + d.len());
let count = response.device_keys.values().fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Devices missing sessions collecting");
group.throughput(Throughput::Elements(count as u64));
let name = format!("{} devices", count);
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(&runtime).iter_with_large_drop(|| async {
@ -306,9 +248,7 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime)

View file

@ -45,10 +45,7 @@ pub struct FlamegraphProfiler<'a> {
impl<'a> FlamegraphProfiler<'a> {
pub fn new(frequency: c_int) -> Self {
FlamegraphProfiler {
frequency,
active_profiler: None,
}
FlamegraphProfiler { frequency, active_profiler: None }
}
}

View file

@ -55,10 +55,7 @@ impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> {
if hash.as_slice() == self.expected_hash.as_slice() {
Ok(0)
} else {
Err(IoError::new(
ErrorKind::Other,
"Hash missmatch while decrypting",
))
Err(IoError::new(ErrorKind::Other, "Hash missmatch while decrypting"))
}
} else {
self.sha.update(&buf[0..read_bytes]);
@ -126,23 +123,14 @@ impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
return Err(DecryptorError::UnknownVersion);
}
let hash = decode(
info.hashes
.get("sha256")
.ok_or(DecryptorError::MissingHash)?,
)?;
let hash = decode(info.hashes.get("sha256").ok_or(DecryptorError::MissingHash)?)?;
let key = Zeroizing::from(decode_url_safe(info.web_key.k)?);
let iv = decode(info.iv)?;
let sha = Sha256::default();
let aes = Aes256Ctr::new_var(&key, &iv).map_err(|_| DecryptorError::KeyNonceLength)?;
Ok(AttachmentDecryptor {
inner_reader: input,
expected_hash: hash,
sha,
aes,
})
Ok(AttachmentDecryptor { inner_reader: input, expected_hash: hash, sha, aes })
}
}
@ -164,9 +152,7 @@ impl<'a, R: Read + 'a> Read for AttachmentEncryptor<'a, R> {
if read_bytes == 0 {
let hash = self.sha.finalize_reset();
self.hashes
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
Ok(0)
} else {
self.aes.apply_keystream(&mut buf[0..read_bytes]);
@ -240,9 +226,7 @@ impl<'a, R: Read + 'a> AttachmentEncryptor<'a, R> {
/// Consume the encryptor and get the encryption key.
pub fn finish(mut self) -> EncryptionInfo {
let hash = self.sha.finalize();
self.hashes
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
EncryptionInfo {
version: VERSION.to_string(),

View file

@ -98,14 +98,10 @@ pub fn decrypt_key_export(
return Err(KeyExportError::InvalidHeaders);
}
let payload: String = x
.lines()
.filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER)))
.collect();
let payload: String =
x.lines().filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER))).collect();
Ok(serde_json::from_str(&decrypt_helper(
&payload, passphrase,
)?)?)
Ok(serde_json::from_str(&decrypt_helper(&payload, passphrase)?)?)
}
/// Encrypt the list of exported room keys using the given passphrase.
@ -260,10 +256,7 @@ mod test {
"};
fn export_wihtout_headers() -> String {
TEST_EXPORT
.lines()
.filter(|l| !l.starts_with("-----"))
.collect()
TEST_EXPORT.lines().filter(|l| !l.starts_with("-----")).collect()
}
#[test]
@ -300,14 +293,8 @@ mod test {
let (machine, _) = get_prepared_machine().await;
let room_id = room_id!("!test:localhost");
machine
.create_outbound_group_session_with_defaults(&room_id)
.await
.unwrap();
let export = machine
.export_keys(|s| s.room_id() == &room_id)
.await
.unwrap();
machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
let export = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap();
assert!(!export.is_empty());
@ -315,10 +302,7 @@ mod test {
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
assert_eq!(export, decrypted);
assert_eq!(
machine.import_keys(decrypted, |_, _| {}).await.unwrap(),
(0, 1)
);
assert_eq!(machine.import_keys(decrypted, |_, _| {}).await.unwrap(), (0, 1));
}
#[test]

View file

@ -113,9 +113,7 @@ pub struct Device {
impl std::fmt::Debug for Device {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Device")
.field("device", &self.inner)
.finish()
f.debug_struct("Device").field("device", &self.inner).finish()
}
}
@ -132,10 +130,7 @@ impl Device {
///
/// Returns a `Sas` object and to-device request that needs to be sent out.
pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> {
let (sas, request) = self
.verification_machine
.start_sas(self.inner.clone())
.await?;
let (sas, request) = self.verification_machine.start_sas(self.inner.clone()).await?;
if let OutgoingVerificationRequest::ToDevice(r) = request {
Ok((sas, r))
@ -155,8 +150,7 @@ impl Device {
/// Get the trust state of the device.
pub fn trust_state(&self) -> bool {
self.inner
.trust_state(&self.own_identity, &self.device_owner_identity)
self.inner.trust_state(&self.own_identity, &self.device_owner_identity)
}
/// Set the local trust state of the device to the given state.
@ -171,10 +165,7 @@ impl Device {
self.inner.set_trust_state(trust_state);
let changes = Changes {
devices: DeviceChanges {
changed: vec![self.inner.clone()],
..Default::default()
},
devices: DeviceChanges { changed: vec![self.inner.clone()], ..Default::default() },
..Default::default()
};
@ -193,9 +184,7 @@ impl Device {
event_type: EventType,
content: Value,
) -> OlmResult<(Session, EncryptedEventContent)> {
self.inner
.encrypt(&**self.verification_machine.store, event_type, content)
.await
self.inner.encrypt(&**self.verification_machine.store, event_type, content).await
}
/// Encrypt the given inbound group session as a forwarded room key for this
@ -254,9 +243,7 @@ impl UserDevices {
/// Returns true if there is at least one devices of this user that is
/// considered to be verified, false otherwise.
pub fn is_any_verified(&self) -> bool {
self.inner
.values()
.any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
self.inner.values().any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
}
/// Iterator over all the device ids of the user devices.
@ -341,8 +328,7 @@ impl ReadOnlyDevice {
/// Get the key of the given key algorithm belonging to this device.
pub fn get_key(&self, algorithm: DeviceKeyAlgorithm) -> Option<&String> {
self.keys
.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
self.keys.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
}
/// Get a map containing all the device keys.
@ -489,9 +475,8 @@ impl ReadOnlyDevice {
}
fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
let signing_key = self
.get_key(DeviceKeyAlgorithm::Ed25519)
.ok_or(SignatureError::MissingSigningKey)?;
let signing_key =
self.get_key(DeviceKeyAlgorithm::Ed25519).ok_or(SignatureError::MissingSigningKey)?;
let utility = Utility::new();
@ -634,10 +619,7 @@ pub(crate) mod test {
assert_eq!(device_id, device.device_id());
assert_eq!(device.algorithms.len(), 2);
assert_eq!(LocalTrust::Unset, device.local_trust_state());
assert_eq!(
"Alice's mobile phone",
device.display_name().as_ref().unwrap()
);
assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
assert_eq!(
device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap(),
"xfgbLIC5WAl1OIkpOzoxpCe8FsRDT6nch7NQsOb15nc"
@ -652,10 +634,7 @@ pub(crate) mod test {
fn update_a_device() {
let mut device = get_device();
assert_eq!(
"Alice's mobile phone",
device.display_name().as_ref().unwrap()
);
assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
let display_name = "Alice's work computer".to_owned();

View file

@ -54,11 +54,7 @@ impl IdentityManager {
const MAX_KEY_QUERY_USERS: usize = 250;
pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceIdBox>, store: Store) -> Self {
IdentityManager {
user_id,
device_id,
store,
}
IdentityManager { user_id, device_id, store }
}
fn user_id(&self) -> &UserId {
@ -78,9 +74,8 @@ impl IdentityManager {
&self,
response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
let changed_devices = self
.handle_devices_from_key_query(response.device_keys.clone())
.await?;
let changed_devices =
self.handle_devices_from_key_query(response.device_keys.clone()).await?;
let changed_identities = self.handle_cross_singing_keys(response).await?;
let changes = Changes {
@ -104,9 +99,8 @@ impl IdentityManager {
store: Store,
device_keys: DeviceKeys,
) -> StoreResult<DeviceChange> {
let old_device = store
.get_readonly_device(&device_keys.user_id, &device_keys.device_id)
.await?;
let old_device =
store.get_readonly_device(&device_keys.user_id, &device_keys.device_id).await?;
if let Some(mut device) = old_device {
if let Err(e) = device.update_device(&device_keys) {
@ -148,25 +142,20 @@ impl IdentityManager {
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
let tasks = device_map
.into_iter()
.filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store.
if user_id == *own_user_id && device_id == *own_device_id {
None
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
None
} else {
Some(spawn(Self::update_or_create_device(
store.clone(),
device_keys,
)))
}
});
let tasks = device_map.into_iter().filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store.
if user_id == *own_user_id && device_id == *own_device_id {
None
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
None
} else {
Some(spawn(Self::update_or_create_device(store.clone(), device_keys)))
}
});
let results = join_all(tasks).await;
@ -211,17 +200,15 @@ impl IdentityManager {
) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default();
let tasks = device_keys_map
.into_iter()
.map(|(user_id, device_keys_map)| {
spawn(Self::update_user_devices(
self.store.clone(),
self.user_id.clone(),
self.device_id.clone(),
user_id,
device_keys_map,
))
});
let tasks = device_keys_map.into_iter().map(|(user_id, device_keys_map)| {
spawn(Self::update_user_devices(
self.store.clone(),
self.user_id.clone(),
self.device_id.clone(),
user_id,
device_keys_map,
))
});
let results = join_all(tasks).await;
@ -254,10 +241,7 @@ impl IdentityManager {
let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) {
SelfSigningPubkey::from(s)
} else {
warn!(
"User identity for user {} didn't contain a self signing pubkey",
user_id
);
warn!("User identity for user {} didn't contain a self signing pubkey", user_id);
continue;
};
@ -276,13 +260,11 @@ impl IdentityManager {
continue;
};
identity
.update(master_key, self_signing, user_signing)
.map(|_| (i, false))
identity.update(master_key, self_signing, user_signing).map(|_| (i, false))
}
UserIdentities::Other(ref mut identity) => {
identity.update(master_key, self_signing).map(|_| (i, false))
}
UserIdentities::Other(ref mut identity) => identity
.update(master_key, self_signing)
.map(|_| (i, false)),
}
} else if user_id == self.user_id() {
if let Some(s) = response.user_signing_keys.get(user_id) {
@ -310,10 +292,7 @@ impl IdentityManager {
continue;
}
} else if master_key.user_id() != user_id || self_signing.user_id() != user_id {
warn!(
"User id mismatch in one of the cross signing keys for user {}",
user_id
);
warn!("User id mismatch in one of the cross signing keys for user {}", user_id);
continue;
} else {
UserIdentity::new(master_key, self_signing)
@ -322,11 +301,7 @@ impl IdentityManager {
match result {
Ok((i, new)) => {
trace!(
"Updated or created new user identity for {}: {:?}",
user_id,
i
);
trace!("Updated or created new user identity for {}: {:?}", user_id, i);
if new {
changes.new.push(i);
} else {
@ -334,10 +309,7 @@ impl IdentityManager {
}
}
Err(e) => {
warn!(
"Couldn't update or create new user identity for {}: {:?}",
user_id, e
);
warn!("Couldn't update or create new user identity for {}: {:?}", user_id, e);
continue;
}
}
@ -635,10 +607,7 @@ pub(crate) mod test {
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0);
manager
.receive_keys_query_response(&other_key_query())
.await
.unwrap();
manager.receive_keys_query_response(&other_key_query()).await.unwrap();
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 1);
@ -649,12 +618,7 @@ pub(crate) mod test {
.await
.unwrap()
.unwrap();
let identity = manager
.store
.get_user_identity(&other_user)
.await
.unwrap()
.unwrap();
let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
let identity = identity.other().unwrap();
assert!(identity.is_device_signed(&device).is_ok())
@ -667,10 +631,7 @@ pub(crate) mod test {
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0);
manager
.receive_keys_query_response(&other_key_query())
.await
.unwrap();
manager.receive_keys_query_response(&other_key_query()).await.unwrap();
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 1);
@ -681,12 +642,7 @@ pub(crate) mod test {
.await
.unwrap()
.unwrap();
let identity = manager
.store
.get_user_identity(&other_user)
.await
.unwrap()
.unwrap();
let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
let identity = identity.other().unwrap();
assert!(identity.is_device_signed(&device).is_ok())

View file

@ -225,12 +225,7 @@ impl MasterPubkey {
&self,
subkey: impl Into<CrossSigningSubKeys<'a>>,
) -> Result<(), SignatureError> {
let (key_id, key) = self
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
let key_id = DeviceKeyId::try_from(key_id.as_str())?;
@ -287,12 +282,7 @@ impl UserSigningPubkey {
&self,
master_key: &MasterPubkey,
) -> Result<(), SignatureError> {
let (key_id, key) = self
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
// TODO check that the usage is OK.
@ -335,12 +325,7 @@ impl SelfSigningPubkey {
/// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed.
pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> {
let (key_id, key) = self
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
// TODO check that the usage is OK.
@ -472,37 +457,16 @@ impl UserIdentity {
) -> Result<Self, SignatureError> {
master_key.verify_subkey(&self_signing_key)?;
Ok(Self {
user_id: Arc::new(master_key.0.user_id.clone()),
master_key,
self_signing_key,
})
Ok(Self { user_id: Arc::new(master_key.0.user_id.clone()), master_key, self_signing_key })
}
#[cfg(test)]
pub async fn from_private(identity: &PrivateCrossSigningIdentity) -> Self {
let master_key = identity
.master_key
.lock()
.await
.as_ref()
.unwrap()
.public_key
.clone();
let self_signing_key = identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key
.clone();
let master_key = identity.master_key.lock().await.as_ref().unwrap().public_key.clone();
let self_signing_key =
identity.self_signing_key.lock().await.as_ref().unwrap().public_key.clone();
Self {
user_id: Arc::new(identity.user_id().clone()),
master_key,
self_signing_key,
}
Self { user_id: Arc::new(identity.user_id().clone()), master_key, self_signing_key }
}
/// Get the user id of this identity.
@ -644,8 +608,7 @@ impl OwnUserIdentity {
/// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed.
pub fn is_identity_signed(&self, identity: &UserIdentity) -> Result<(), SignatureError> {
self.user_signing_key
.verify_master_key(&identity.master_key)
self.user_signing_key.verify_master_key(&identity.master_key)
}
/// Check if the given device has been signed by this identity.
@ -790,9 +753,8 @@ pub(crate) mod test {
assert!(identity.is_device_signed(&first).is_err());
assert!(identity.is_device_signed(&second).is_ok());
let private_identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(
second.user_id().clone(),
)));
let private_identity =
Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(second.user_id().clone())));
let verification_machine = VerificationMachine::new(
ReadOnlyAccount::new(second.user_id(), second.device_id()),
private_identity.clone(),

View file

@ -105,10 +105,8 @@ impl WaitQueue {
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Vec<(
(UserId, DeviceIdBox, String),
ToDeviceEvent<RoomKeyRequestToDeviceEventContent>,
)> {
) -> Vec<((UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>)>
{
self.requests_ids_waiting
.remove(&(user_id.to_owned(), device_id.into()))
.map(|(_, request_ids)| {
@ -204,12 +202,7 @@ fn wrap_key_request_content(
Ok(OutgoingRequest {
request_id: id,
request: Arc::new(
ToDeviceRequest {
event_type: EventType::RoomKeyRequest,
txn_id: id,
messages,
}
.into(),
ToDeviceRequest { event_type: EventType::RoomKeyRequest, txn_id: id, messages }.into(),
),
})
}
@ -241,10 +234,7 @@ impl KeyRequestMachine {
.await?
.into_iter()
.filter(|i| !i.sent_out)
.map(|info| {
info.to_request(self.device_id())
.map_err(CryptoStoreError::from)
})
.map(|info| info.to_request(self.device_id()).map_err(CryptoStoreError::from))
.collect()
}
@ -262,11 +252,8 @@ impl KeyRequestMachine {
&self,
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let mut key_requests = self.load_outgoing_requests().await?;
let key_forwards: Vec<OutgoingRequest> = self
.outgoing_to_device_requests
.iter()
.map(|i| i.value().clone())
.collect();
let key_forwards: Vec<OutgoingRequest> =
self.outgoing_to_device_requests.iter().map(|i| i.value().clone()).collect();
key_requests.extend(key_forwards);
Ok(key_requests)
@ -281,8 +268,7 @@ impl KeyRequestMachine {
let device_id = event.content.requesting_device_id.clone();
let request_id = event.content.request_id.clone();
self.incoming_key_requests
.insert((sender, device_id, request_id), event.clone());
self.incoming_key_requests.insert((sender, device_id, request_id), event.clone());
}
/// Handle all the incoming key requests that are queued up and empty our
@ -401,10 +387,8 @@ impl KeyRequestMachine {
return Ok(None);
};
let device = self
.store
.get_device(&event.sender, &event.content.requesting_device_id)
.await?;
let device =
self.store.get_device(&event.sender, &event.content.requesting_device_id).await?;
if let Some(device) = device {
match self.should_share_key(&device, &session).await {
@ -461,30 +445,22 @@ impl KeyRequestMachine {
device: &Device,
message_index: Option<u32>,
) -> OlmResult<Session> {
let (used_session, content) = device
.encrypt_session(session.clone(), message_index)
.await?;
let (used_session, content) =
device.encrypt_session(session.clone(), message_index).await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
messages
.entry(device.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
let request = OutgoingRequest {
request_id: id,
request: Arc::new(
ToDeviceRequest {
event_type: EventType::RoomEncrypted,
txn_id: id,
messages,
}
.into(),
ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages }
.into(),
),
};
@ -542,8 +518,8 @@ impl KeyRequestMachine {
} else {
Err(KeyshareDecision::OutboundSessionNotShared)
}
// Else just check if it's one of our own devices that requested the key and
// check if the device is trusted.
// Else just check if it's one of our own devices that requested the key
// and check if the device is trusted.
} else if device.user_id() == self.user_id() {
own_device_check()
// Otherwise, there's not enough info to decide if we can safely share
@ -711,9 +687,7 @@ impl KeyRequestMachine {
/// Delete the given outgoing key info.
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
self.store
.delete_outgoing_key_request(info.request_id)
.await
self.store.delete_outgoing_key_request(info.request_id).await
}
/// Mark the outgoing request as sent.
@ -736,20 +710,15 @@ impl KeyRequestMachine {
/// This will queue up a request cancelation.
async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
// TODO perhaps only remove the key info if the first known index is 0.
trace!(
"Successfully received a forwarded room key for {:#?}",
key_info
);
trace!("Successfully received a forwarded room key for {:#?}", key_info);
self.outgoing_to_device_requests
.remove(&key_info.request_id);
self.outgoing_to_device_requests.remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction.
self.delete_key_info(&key_info).await?;
let request = key_info.to_cancelation(self.device_id())?;
self.outgoing_to_device_requests
.insert(request.request_id, request);
self.outgoing_to_device_requests.insert(request.request_id, request);
Ok(())
}
@ -801,10 +770,7 @@ impl KeyRequestMachine {
);
}
Ok((
Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())),
session,
))
Ok((Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())), session))
} else {
info!(
"Received a forwarded room key from {}, but no key info was found.",
@ -919,11 +885,7 @@ mod test {
async fn create_machine() {
let machine = get_machine().await;
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
}
#[async_test]
@ -931,16 +893,10 @@ mod test {
let machine = get_machine().await;
let account = account();
let (_, session) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (_, session) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
let (cancel, request) = machine
.request_key(session.room_id(), &session.sender_key, session.session_id())
.await
@ -948,10 +904,7 @@ mod test {
assert!(cancel.is_none());
machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
let (cancel, _) = machine
.request_key(session.room_id(), &session.sender_key, session.session_id())
@ -972,16 +925,10 @@ mod test {
alice_device.set_trust_state(LocalTrust::Verified);
machine.store.save_devices(&[alice_device]).await.unwrap();
let (_, session) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (_, session) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
machine
.create_outgoing_key_request(
session.room_id(),
@ -990,15 +937,8 @@ mod test {
)
.await
.unwrap();
assert!(!machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert_eq!(
machine.outgoing_to_device_requests().await.unwrap().len(),
1
);
assert!(!machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert_eq!(machine.outgoing_to_device_requests().await.unwrap().len(), 1);
machine
.create_outgoing_key_request(
@ -1014,15 +954,8 @@ mod test {
let request = requests.get(0).unwrap();
machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
}
#[async_test]
@ -1037,10 +970,8 @@ mod test {
alice_device.set_trust_state(LocalTrust::Verified);
machine.store.save_devices(&[alice_device]).await.unwrap();
let (_, session) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (_, session) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
machine
.create_outgoing_key_request(
session.room_id(),
@ -1060,10 +991,7 @@ mod test {
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent {
sender: alice_id(),
content,
};
let mut event = ToDeviceEvent { sender: alice_id(), content };
assert!(
machine
@ -1078,19 +1006,13 @@ mod test {
.is_none()
);
let (_, first_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let (_, first_session) =
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
let first_session = first_session.unwrap();
assert_eq!(first_session.first_known_index(), 10);
machine
.store
.save_inbound_group_sessions(&[first_session.clone()])
.await
.unwrap();
machine.store.save_inbound_group_sessions(&[first_session.clone()]).await.unwrap();
// Get the cancel request.
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
@ -1110,24 +1032,16 @@ mod test {
let requests = machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0];
machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
let export = session.export_at_index(15).await;
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent {
sender: alice_id(),
content,
};
let mut event = ToDeviceEvent { sender: alice_id(), content };
let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let (_, second_session) =
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
assert!(second_session.is_none());
@ -1135,15 +1049,10 @@ mod test {
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent {
sender: alice_id(),
content,
};
let mut event = ToDeviceEvent { sender: alice_id(), content };
let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let (_, second_session) =
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
assert_eq!(second_session.unwrap().first_known_index(), 0);
}
@ -1153,17 +1062,11 @@ mod test {
let machine = get_machine().await;
let account = account();
let own_device = machine
.store
.get_device(&alice_id(), &alice_device_id())
.await
.unwrap()
.unwrap();
let own_device =
machine.store.get_device(&alice_id(), &alice_device_id()).await.unwrap().unwrap();
let (outbound, inbound) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (outbound, inbound) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
// We don't share keys with untrusted devices.
assert_eq!(
@ -1175,20 +1078,13 @@ mod test {
);
own_device.set_trust_state(LocalTrust::Verified);
// Now we do want to share the keys.
assert!(machine
.should_share_key(&own_device, &inbound)
.await
.is_ok());
assert!(machine.should_share_key(&own_device, &inbound).await.is_ok());
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
machine.store.save_devices(&[bob_device]).await.unwrap();
let bob_device = machine
.store
.get_device(&bob_id(), &bob_device_id())
.await
.unwrap()
.unwrap();
let bob_device =
machine.store.get_device(&bob_id(), &bob_device_id()).await.unwrap().unwrap();
// We don't share sessions with other user's devices if no outbound
// session was provided.
@ -1231,17 +1127,12 @@ mod test {
// We now share the session, since it was shared before.
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
assert!(machine
.should_share_key(&bob_device, &inbound)
.await
.is_ok());
assert!(machine.should_share_key(&bob_device, &inbound).await.is_ok());
// But we don't share some other session that doesn't match our outbound
// session
let (_, other_inbound) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (_, other_inbound) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
assert_eq!(
machine
@ -1255,10 +1146,7 @@ mod test {
#[async_test]
async fn key_share_cycle() {
let alice_machine = get_machine().await;
let alice_account = Account {
inner: account(),
store: alice_machine.store.clone(),
};
let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
let bob_machine = bob_machine();
let bob_account = bob_account();
@ -1268,11 +1156,7 @@ mod test {
// We need a trusted device, otherwise we won't request keys
alice_device.set_trust_state(LocalTrust::Verified);
alice_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
alice_machine.store.save_devices(&[alice_device]).await.unwrap();
// Create Olm sessions for our two accounts.
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
@ -1282,37 +1166,15 @@ mod test {
// Populate our stores with Olm sessions and a Megolm session.
alice_machine
.store
.save_sessions(&[alice_session])
.await
.unwrap();
alice_machine
.store
.save_devices(&[bob_device])
.await
.unwrap();
bob_machine
.store
.save_sessions(&[bob_session])
.await
.unwrap();
bob_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
alice_machine.store.save_devices(&[bob_device]).await.unwrap();
bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
bob_machine.store.save_devices(&[alice_device]).await.unwrap();
let (group_session, inbound_group_session) = bob_account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (group_session, inbound_group_session) =
bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
bob_machine
.store
.save_inbound_group_sessions(&[inbound_group_session])
.await
.unwrap();
bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
// Alice wants to request the outbound group session from bob.
alice_machine
@ -1326,9 +1188,7 @@ mod test {
group_session.mark_shared_with(&alice_id(), &alice_device_id());
// Put the outbound session into bobs store.
bob_machine
.outbound_group_sessions
.insert(group_session.clone());
bob_machine.outbound_group_sessions.insert(group_session.clone());
// Get the request and convert it into a event.
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
@ -1346,15 +1206,9 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap();
alice_machine
.mark_outgoing_request_as_sent(id)
.await
.unwrap();
alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
sender: alice_id(),
content,
};
let event = ToDeviceEvent { sender: alice_id(), content };
// Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests.is_empty());
@ -1383,10 +1237,7 @@ mod test {
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
sender: bob_id(),
content,
};
let event = ToDeviceEvent { sender: bob_id(), content };
// Check that alice doesn't have the session.
assert!(alice_machine
@ -1407,11 +1258,7 @@ mod test {
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await
.unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
} else {
panic!("Invalid decrypted event type");
}
@ -1434,10 +1281,7 @@ mod test {
#[async_test]
async fn key_share_cycle_without_session() {
let alice_machine = get_machine().await;
let alice_account = Account {
inner: account(),
store: alice_machine.store.clone(),
};
let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
let bob_machine = bob_machine();
let bob_account = bob_account();
@ -1447,11 +1291,7 @@ mod test {
// We need a trusted device, otherwise we won't request keys
alice_device.set_trust_state(LocalTrust::Verified);
alice_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
alice_machine.store.save_devices(&[alice_device]).await.unwrap();
// Create Olm sessions for our two accounts.
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
@ -1461,27 +1301,13 @@ mod test {
// Populate our stores with Olm sessions and a Megolm session.
alice_machine
.store
.save_devices(&[bob_device])
.await
.unwrap();
bob_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
alice_machine.store.save_devices(&[bob_device]).await.unwrap();
bob_machine.store.save_devices(&[alice_device]).await.unwrap();
let (group_session, inbound_group_session) = bob_account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (group_session, inbound_group_session) =
bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
bob_machine
.store
.save_inbound_group_sessions(&[inbound_group_session])
.await
.unwrap();
bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
// Alice wants to request the outbound group session from bob.
alice_machine
@ -1495,9 +1321,7 @@ mod test {
group_session.mark_shared_with(&alice_id(), &alice_device_id());
// Put the outbound session into bobs store.
bob_machine
.outbound_group_sessions
.insert(group_session.clone());
bob_machine.outbound_group_sessions.insert(group_session.clone());
// Get the request and convert it into a event.
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
@ -1515,22 +1339,12 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap();
alice_machine
.mark_outgoing_request_as_sent(id)
.await
.unwrap();
alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
sender: alice_id(),
content,
};
let event = ToDeviceEvent { sender: alice_id(), content };
// Bob doesn't have any outgoing requests.
assert!(bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(bob_machine.users_for_key_claim.is_empty());
assert!(bob_machine.wait_queue.is_empty());
@ -1538,35 +1352,19 @@ mod test {
bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob doens't have an outgoing requests since we're lacking a session.
assert!(bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(!bob_machine.users_for_key_claim.is_empty());
assert!(!bob_machine.wait_queue.is_empty());
// We create a session now.
alice_machine
.store
.save_sessions(&[alice_session])
.await
.unwrap();
bob_machine
.store
.save_sessions(&[bob_session])
.await
.unwrap();
alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
bob_machine.retry_keyshare(&alice_id(), &alice_device_id());
assert!(bob_machine.users_for_key_claim.is_empty());
bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob now has an outgoing requests.
assert!(!bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(!bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(bob_machine.wait_queue.is_empty());
// Get the request and convert it to a encrypted to-device event.
@ -1588,10 +1386,7 @@ mod test {
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
sender: bob_id(),
content,
};
let event = ToDeviceEvent { sender: bob_id(), content };
// Check that alice doesn't have the session.
assert!(alice_machine
@ -1612,11 +1407,7 @@ mod test {
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await
.unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
} else {
panic!("Invalid decrypted event type");
}

View file

@ -148,19 +148,12 @@ impl OlmMachine {
let store = Arc::new(store);
let verification_machine =
VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
let store = Store::new(
user_id.clone(),
user_identity.clone(),
store,
verification_machine.clone(),
);
let store =
Store::new(user_id.clone(), user_identity.clone(), store, verification_machine.clone());
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
let users_for_key_claim = Arc::new(DashMap::new());
let account = Account {
inner: account,
store: store.clone(),
};
let account = Account { inner: account, store: store.clone() };
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
@ -244,9 +237,7 @@ impl OlmMachine {
}
};
Ok(OlmMachine::new_helper(
&user_id, device_id, store, account, identity,
))
Ok(OlmMachine::new_helper(&user_id, device_id, store, account, identity))
}
/// Create a new machine with the default crypto store.
@ -296,19 +287,16 @@ impl OlmMachine {
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
let mut requests = Vec::new();
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
request_id: Uuid::new_v4(),
request: Arc::new(r.into()),
}) {
if let Some(r) = self
.keys_for_upload()
.await
.map(|r| OutgoingRequest { request_id: Uuid::new_v4(), request: Arc::new(r.into()) })
{
requests.push(r);
}
for request in self
.identity_manager
.users_for_key_query()
.await
.into_iter()
.map(|r| OutgoingRequest {
for request in
self.identity_manager.users_for_key_query().await.into_iter().map(|r| OutgoingRequest {
request_id: Uuid::new_v4(),
request: Arc::new(r.into()),
})
@ -318,12 +306,7 @@ impl OlmMachine {
requests.append(&mut self.outgoing_to_device_requests());
requests.append(&mut self.verification_machine.outgoing_room_message_requests());
requests.append(
&mut self
.key_request_machine
.outgoing_to_device_requests()
.await?,
);
requests.append(&mut self.key_request_machine.outgoing_to_device_requests().await?);
Ok(requests)
}
@ -374,10 +357,7 @@ impl OlmMachine {
let identity = self.user_identity.lock().await;
identity.mark_as_shared();
let changes = Changes {
private_identity: Some(identity.clone()),
..Default::default()
};
let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
self.store.save_changes(changes).await
}
@ -407,10 +387,7 @@ impl OlmMachine {
);
let changes = Changes {
identities: IdentityChanges {
new: vec![public.into()],
..Default::default()
},
identities: IdentityChanges { new: vec![public.into()], ..Default::default() },
private_identity: Some(identity.clone()),
..Default::default()
};
@ -422,10 +399,8 @@ impl OlmMachine {
info!("Trying to upload the existing cross signing identity");
let request = identity.as_upload_request().await;
// TODO remove this expect.
let signature_request = identity
.sign_account(&self.account)
.await
.expect("Can't sign device keys");
let signature_request =
identity.sign_account(&self.account).await.expect("Can't sign device keys");
Ok((request, signature_request))
}
}
@ -519,9 +494,7 @@ impl OlmMachine {
///
/// * `response` - The response containing the claimed one-time keys.
async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
self.session_manager
.receive_keys_claim_response(response)
.await
self.session_manager.receive_keys_claim_response(response).await
}
/// Receive a successful keys query response.
@ -537,9 +510,7 @@ impl OlmMachine {
&self,
response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
self.identity_manager
.receive_keys_query_response(response)
.await
self.identity_manager.receive_keys_query_response(response).await
}
/// Get a request to upload E2EE keys to the server.
@ -677,9 +648,7 @@ impl OlmMachine {
/// Returns true if a session was invalidated, false if there was no session
/// to invalidate.
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
self.group_session_manager
.invalidate_group_session(room_id)
.await
self.group_session_manager.invalidate_group_session(room_id).await
}
/// Get to-device requests to share a group session with users in a room.
@ -696,9 +665,7 @@ impl OlmMachine {
users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.group_session_manager
.share_group_session(room_id, users, encryption_settings)
.await
self.group_session_manager.share_group_session(room_id, users, encryption_settings).await
}
/// Receive and properly handle a decrypted to-device event.
@ -717,18 +684,15 @@ impl OlmMachine {
let event = match decrypted.event.deserialize() {
Ok(e) => e,
Err(e) => {
warn!(
"Decrypted to-device event failed to be parsed correctly {:?}",
e
);
warn!("Decrypted to-device event failed to be parsed correctly {:?}", e);
return Ok((None, None));
}
};
match event {
AnyToDeviceEvent::RoomKey(mut e) => Ok(self
.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e)
.await?),
AnyToDeviceEvent::RoomKey(mut e) => {
Ok(self.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e).await?)
}
AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self
.key_request_machine
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
@ -754,14 +718,9 @@ impl OlmMachine {
/// Mark an outgoing to-device requests as sent.
async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
self.verification_machine.mark_request_as_sent(request_id);
self.key_request_machine
.mark_outgoing_request_as_sent(*request_id)
.await?;
self.group_session_manager
.mark_request_as_sent(request_id)
.await?;
self.session_manager
.mark_outgoing_request_as_sent(request_id);
self.key_request_machine.mark_outgoing_request_as_sent(*request_id).await?;
self.group_session_manager.mark_request_as_sent(request_id).await?;
self.session_manager.mark_outgoing_request_as_sent(request_id);
Ok(())
}
@ -810,10 +769,8 @@ impl OlmMachine {
// Always save the account, a new session might get created which also
// touches the account.
let mut changes = Changes {
account: Some(self.account.inner.clone()),
..Default::default()
};
let mut changes =
Changes { account: Some(self.account.inner.clone()), ..Default::default() };
self.update_one_time_key_count(one_time_keys_counts).await;
@ -830,10 +787,7 @@ impl OlmMachine {
Ok(e) => e,
Err(e) => {
// Skip invalid events.
warn!(
"Received an invalid to-device event {:?} {:?}",
e, raw_event
);
warn!("Received an invalid to-device event {:?} {:?}", e, raw_event);
continue;
}
};
@ -845,10 +799,7 @@ impl OlmMachine {
let decrypted = match self.decrypt_to_device_event(&e).await {
Ok(e) => e,
Err(err) => {
warn!(
"Failed to decrypt to-device event from {} {}",
e.sender, err
);
warn!("Failed to decrypt to-device event from {} {}", e.sender, err);
if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self
@ -903,10 +854,7 @@ impl OlmMachine {
events.push(raw_event);
}
let changed_sessions = self
.key_request_machine
.collect_incoming_key_requests()
.await?;
let changed_sessions = self.key_request_machine.collect_incoming_key_requests().await?;
changes.sessions.extend(changed_sessions);
@ -1023,25 +971,16 @@ impl OlmMachine {
// TODO check if this is from a verified device.
let (decrypted_event, _) = session.decrypt(event).await?;
trace!(
"Successfully decrypted a Megolm event {:?}",
decrypted_event
);
trace!("Successfully decrypted a Megolm event {:?}", decrypted_event);
if let Ok(e) = decrypted_event.deserialize() {
self.verification_machine
.receive_room_event(room_id, &e)
.await?;
self.verification_machine.receive_room_event(room_id, &e).await?;
}
let encryption_info = self
.get_encryption_info(&session, &event.sender, &content.device_id)
.await?;
let encryption_info =
self.get_encryption_info(&session, &event.sender, &content.device_id).await?;
Ok(SyncRoomEvent {
encryption_info: Some(encryption_info),
event: decrypted_event,
})
Ok(SyncRoomEvent { encryption_info: Some(encryption_info), event: decrypted_event })
}
/// Update the tracked users.
@ -1197,17 +1136,11 @@ impl OlmMachine {
let num_sessions = sessions.len();
let changes = Changes {
inbound_group_sessions: sessions,
..Default::default()
};
let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
self.store.save_changes(changes).await?;
info!(
"Successfully imported {} inbound group sessions",
num_sessions
);
info!("Successfully imported {} inbound group sessions", num_sessions);
Ok((num_sessions, total_sessions))
}
@ -1318,10 +1251,7 @@ pub(crate) mod test {
}
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
Response::builder()
.status(200)
.body(json.to_string().as_bytes().to_vec())
.unwrap()
Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
}
fn keys_upload_response() -> upload_keys::Response {
@ -1340,15 +1270,7 @@ pub(crate) mod test {
let to_device_request = &requests[0];
let content: Raw<EncryptedEventContent> = serde_json::from_str(
to_device_request
.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
to_device_request.messages.values().next().unwrap().values().next().unwrap().get(),
)
.unwrap();
@ -1358,15 +1280,9 @@ pub(crate) mod test {
pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.account.inner.update_uploaded_key_count(0);
let request = machine
.keys_for_upload()
.await
.expect("Can't prepare initial key upload");
let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload");
let response = keys_upload_response();
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
machine.receive_keys_upload_response(&response).await.unwrap();
(machine, request.one_time_keys.unwrap())
}
@ -1375,10 +1291,7 @@ pub(crate) mod test {
let (machine, otk) = get_prepared_machine().await;
let response = keys_query_response();
machine
.receive_keys_query_response(&response)
.await
.unwrap();
machine.receive_keys_query_response(&response).await.unwrap();
(machine, otk)
}
@ -1421,28 +1334,15 @@ pub(crate) mod test {
async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) {
let (alice, bob) = get_machine_pair_with_session().await;
let bob_device = alice
.get_device(&bob.user_id, &bob.device_id)
.await
.unwrap()
.unwrap();
let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
let (session, content) = bob_device
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap();
let (session, content) = bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap();
alice.store.save_sessions(&[session]).await.unwrap();
let event = ToDeviceEvent {
sender: alice.user_id().clone(),
content,
};
let event = ToDeviceEvent { sender: alice.user_id().clone(), content };
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
(alice, bob)
}
@ -1458,34 +1358,18 @@ pub(crate) mod test {
let machine = OlmMachine::new(&user_id(), &alice_device_id());
let mut response = keys_upload_response();
response
.one_time_key_counts
.remove(&DeviceKeyAlgorithm::SignedCurve25519)
.unwrap();
response.one_time_key_counts.remove(&DeviceKeyAlgorithm::SignedCurve25519).unwrap();
assert!(machine.should_upload_keys().await);
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.should_upload_keys().await);
response
.one_time_key_counts
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.should_upload_keys().await);
response
.one_time_key_counts
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(!machine.should_upload_keys().await);
}
@ -1497,20 +1381,12 @@ pub(crate) mod test {
assert!(machine.should_upload_keys().await);
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.should_upload_keys().await);
assert!(machine.account.generate_one_time_keys().await.is_ok());
response
.one_time_key_counts
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.account.generate_one_time_keys().await.is_err());
}
@ -1537,14 +1413,8 @@ pub(crate) mod test {
let machine = OlmMachine::new(&user_id(), &alice_device_id());
let room_id = room_id!("!test:example.org");
machine
.create_outbound_group_session_with_defaults(&room_id)
.await
.unwrap();
assert!(machine
.group_session_manager
.get_outbound_group_session(&room_id)
.is_some());
machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
assert!(machine.group_session_manager.get_outbound_group_session(&room_id).is_some());
machine.invalidate_group_session(&room_id).await.unwrap();
@ -1600,10 +1470,8 @@ pub(crate) mod test {
let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519();
let mut request = machine
.keys_for_upload()
.await
.expect("Can't prepare initial key upload");
let mut request =
machine.keys_for_upload().await.expect("Can't prepare initial key upload");
let utility = Utility::new();
let ret = utility.verify_json(
@ -1626,15 +1494,10 @@ pub(crate) mod test {
let mut response = keys_upload_response();
response.one_time_key_counts.insert(
DeviceKeyAlgorithm::SignedCurve25519,
(request.one_time_keys.unwrap().len() as u64)
.try_into()
.unwrap(),
(request.one_time_keys.unwrap().len() as u64).try_into().unwrap(),
);
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
machine.receive_keys_upload_response(&response).await.unwrap();
let ret = machine.keys_for_upload().await;
assert!(ret.is_none());
@ -1650,17 +1513,9 @@ pub(crate) mod test {
let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap();
assert!(alice_devices.devices().peekable().peek().is_none());
machine
.receive_keys_query_response(&response)
.await
.unwrap();
machine.receive_keys_query_response(&response).await.unwrap();
let device = machine
.store
.get_device(&alice_id, alice_device_id)
.await
.unwrap()
.unwrap();
let device = machine.store.get_device(&alice_id, alice_device_id).await.unwrap().unwrap();
assert_eq!(device.user_id(), &alice_id);
assert_eq!(device.device_id(), alice_device_id);
}
@ -1672,11 +1527,8 @@ pub(crate) mod test {
let alice = alice_id();
let alice_device = alice_device_id();
let (_, missing_sessions) = machine
.get_missing_sessions(&mut [alice.clone()].iter())
.await
.unwrap()
.unwrap();
let (_, missing_sessions) =
machine.get_missing_sessions(&mut [alice.clone()].iter()).await.unwrap().unwrap();
assert!(missing_sessions.one_time_keys.contains_key(&alice));
let user_sessions = missing_sessions.one_time_keys.get(&alice).unwrap();
@ -1699,10 +1551,7 @@ pub(crate) mod test {
let response = claim_keys::Response::new(one_time_keys);
alice_machine
.receive_keys_claim_response(&response)
.await
.unwrap();
alice_machine.receive_keys_claim_response(&response).await.unwrap();
let session = alice_machine
.store
@ -1718,28 +1567,14 @@ pub(crate) mod test {
async fn test_olm_encryption() {
let (alice, bob) = get_machine_pair_with_session().await;
let bob_device = alice
.get_device(&bob.user_id, &bob.device_id)
.await
.unwrap()
.unwrap();
let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
let event = ToDeviceEvent {
sender: alice.user_id().clone(),
content: bob_device
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap()
.1,
content: bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap().1,
};
let event = bob
.decrypt_to_device_event(&event)
.await
.unwrap()
.event
.deserialize()
.unwrap();
let event = bob.decrypt_to_device_event(&event).await.unwrap().event.deserialize().unwrap();
if let AnyToDeviceEvent::Dummy(e) = event {
assert_eq!(&e.sender, alice.user_id());
@ -1768,17 +1603,12 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests),
};
let alice_session = alice
.group_session_manager
.get_outbound_group_session(&room_id)
.unwrap();
let alice_session =
alice.group_session_manager.get_outbound_group_session(&room_id).unwrap();
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
bob.store
.save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()])
.await
@ -1823,25 +1653,16 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests),
};
let group_session = bob
.decrypt_to_device_event(&event)
.await
.unwrap()
.inbound_group_session;
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await
.unwrap();
let group_session =
bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session;
bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
let plaintext = "It is a secret to everybody";
let content = MessageEventContent::text_plain(plaintext);
let encrypted_content = alice
.encrypt(
&room_id,
AnyMessageEventContent::RoomMessage(content.clone()),
)
.encrypt(&room_id, AnyMessageEventContent::RoomMessage(content.clone()))
.await
.unwrap();
@ -1853,13 +1674,8 @@ pub(crate) mod test {
unsigned: Unsigned::default(),
};
let decrypted_event = bob
.decrypt_room_event(&event, &room_id)
.await
.unwrap()
.event
.deserialize()
.unwrap();
let decrypted_event =
bob.decrypt_room_event(&event, &room_id).await.unwrap().event.deserialize().unwrap();
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(SyncMessageEvent {
sender,
@ -1898,10 +1714,7 @@ pub(crate) mod test {
let device_id = machine.device_id().to_owned();
let ed25519_key = machine.identity_keys().ed25519().to_owned();
machine
.receive_keys_upload_response(&keys_upload_response())
.await
.unwrap();
machine.receive_keys_upload_response(&keys_upload_response()).await.unwrap();
drop(machine);
@ -1923,11 +1736,7 @@ pub(crate) mod test {
async fn interactive_verification() {
let (alice, bob) = get_machine_pair_with_setup_sessions().await;
let bob_device = alice
.get_device(bob.user_id(), bob.device_id())
.await
.unwrap()
.unwrap();
let bob_device = alice.get_device(bob.user_id(), bob.device_id()).await.unwrap().unwrap();
assert!(!bob_device.is_trusted());
@ -1941,10 +1750,7 @@ pub(crate) mod test {
assert!(alice_sas.emoji().is_none());
assert!(bob_sas.emoji().is_none());
let event = bob_sas
.accept()
.map(|r| request_to_event(bob.user_id(), &r))
.unwrap();
let event = bob_sas.accept().map(|r| request_to_event(bob.user_id(), &r)).unwrap();
alice.handle_verification_event(&event).await;
@ -1991,11 +1797,8 @@ pub(crate) mod test {
assert!(alice_sas.is_done());
assert!(bob_device.is_trusted());
let alice_device = bob
.get_device(alice.user_id(), alice.device_id())
.await
.unwrap()
.unwrap();
let alice_device =
bob.get_device(alice.user_id(), alice.device_id()).await.unwrap().unwrap();
assert!(!alice_device.is_trusted());
bob.handle_verification_event(&event).await;

View file

@ -139,10 +139,8 @@ impl Account {
// Try to find a ciphertext that was meant for our device.
if let Some(ciphertext) = own_ciphertext {
let message_type: u8 = ciphertext
.message_type
.try_into()
.map_err(|_| EventError::UnsupportedOlmType)?;
let message_type: u8 =
ciphertext.message_type.try_into().map_err(|_| EventError::UnsupportedOlmType)?;
let sha = Sha256::new()
.chain(&content.sender_key)
@ -160,20 +158,18 @@ impl Account {
.map_err(|_| EventError::UnsupportedOlmType)?;
// Decrypt the OlmMessage and get a Ruma event out of it.
let (session, event, signing_key) = match self
.decrypt_olm_message(&event.sender, &content.sender_key, message)
.await
{
Ok(d) => d,
Err(OlmError::SessionWedged(user_id, sender_key)) => {
if self.store.is_message_known(&message_hash).await? {
return Err(OlmError::ReplayedMessage(user_id, sender_key));
} else {
return Err(OlmError::SessionWedged(user_id, sender_key));
let (session, event, signing_key) =
match self.decrypt_olm_message(&event.sender, &content.sender_key, message).await {
Ok(d) => d,
Err(OlmError::SessionWedged(user_id, sender_key)) => {
if self.store.is_message_known(&message_hash).await? {
return Err(OlmError::ReplayedMessage(user_id, sender_key));
} else {
return Err(OlmError::SessionWedged(user_id, sender_key));
}
}
}
Err(e) => return Err(e),
};
Err(e) => return Err(e),
};
debug!("Decrypted a to-device event {:?}", event);
@ -208,9 +204,8 @@ impl Account {
}
self.inner.mark_as_shared();
let one_time_key_count = response
.one_time_key_counts
.get(&DeviceKeyAlgorithm::SignedCurve25519);
let one_time_key_count =
response.one_time_key_counts.get(&DeviceKeyAlgorithm::SignedCurve25519);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
debug!(
@ -295,9 +290,8 @@ impl Account {
message: OlmMessage,
) -> OlmResult<(SessionType, Raw<AnyToDeviceEvent>, String)> {
// First try to decrypt using an existing session.
let (session, plaintext) = if let Some(d) = self
.try_decrypt_olm_message(sender, sender_key, &message)
.await?
let (session, plaintext) = if let Some(d) =
self.try_decrypt_olm_message(sender, sender_key, &message).await?
{
// Decryption succeeded, de-structure the session/plaintext out of
// the Option.
@ -314,32 +308,26 @@ impl Account {
available sessions {} {}",
sender, sender_key
);
return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
return Err(OlmError::SessionWedged(sender.to_owned(), sender_key.to_owned()));
}
OlmMessage::PreKey(m) => {
// Create the new session.
let session = match self
.inner
.create_inbound_session(sender_key, m.clone())
.await
{
Ok(s) => s,
Err(e) => {
warn!(
"Failed to create a new Olm session for {} {}
let session =
match self.inner.create_inbound_session(sender_key, m.clone()).await {
Ok(s) => s,
Err(e) => {
warn!(
"Failed to create a new Olm session for {} {}
from a prekey message: {}",
sender, sender_key, e
);
return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
}
};
sender, sender_key, e
);
return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
}
};
session
}
@ -426,9 +414,8 @@ impl Account {
return Err(EventError::MissmatchedKeys.into());
}
let signing_key = keys
.get(&DeviceKeyAlgorithm::Ed25519)
.ok_or(EventError::MissingSigningKey)?;
let signing_key =
keys.get(&DeviceKeyAlgorithm::Ed25519).ok_or(EventError::MissingSigningKey)?;
Ok((
Raw::from(serde_json::from_value::<AnyToDeviceEvent>(decrypted_json)?),
@ -545,8 +532,7 @@ impl ReadOnlyAccount {
/// * `new_count` - The new count that was reported by the server.
pub(crate) fn update_uploaded_key_count(&self, new_count: u64) {
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX);
self.uploaded_signed_key_count
.store(key_count, Ordering::Relaxed);
self.uploaded_signed_key_count.store(key_count, Ordering::Relaxed);
}
/// Get the currently known uploaded key count.
@ -629,19 +615,12 @@ impl ReadOnlyAccount {
/// Returns None if no keys need to be uploaded.
pub(crate) async fn keys_for_upload(
&self,
) -> Option<(
Option<DeviceKeys>,
Option<BTreeMap<DeviceKeyId, OneTimeKey>>,
)> {
) -> Option<(Option<DeviceKeys>, Option<BTreeMap<DeviceKeyId, OneTimeKey>>)> {
if !self.should_upload_keys().await {
return None;
}
let device_keys = if !self.shared() {
Some(self.device_keys().await)
} else {
None
};
let device_keys = if !self.shared() { Some(self.device_keys().await) } else { None };
let one_time_keys = self.signed_one_time_keys().await.ok();
@ -664,7 +643,8 @@ impl ReadOnlyAccount {
///
/// # Arguments
///
/// * `pickle_mode` - The mode that was used to pickle the account, either an
/// * `pickle_mode` - The mode that was used to pickle the account, either
/// an
/// unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledAccount {
let pickle = AccountPickle(self.inner.lock().await.pickle(pickle_mode));
@ -684,7 +664,8 @@ impl ReadOnlyAccount {
///
/// * `pickle` - The pickled version of the Account.
///
/// * `pickle_mode` - The mode that was used to pickle the account, either an
/// * `pickle_mode` - The mode that was used to pickle the account, either
/// an
/// unencrypted mode or an encrypted using passphrase.
pub fn from_pickle(
pickle: PickledAccount,
@ -740,25 +721,17 @@ impl ReadOnlyAccount {
"keys": device_keys.keys,
});
device_keys
.signatures
.entry(self.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
self.sign_json(json_device_keys).await,
);
device_keys.signatures.entry(self.user_id().clone()).or_insert_with(BTreeMap::new).insert(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
self.sign_json(json_device_keys).await,
);
device_keys
}
pub(crate) async fn bootstrap_cross_signing(
&self,
) -> (
PrivateCrossSigningIdentity,
UploadSigningKeysRequest,
SignatureUploadRequest,
) {
) -> (PrivateCrossSigningIdentity, UploadSigningKeysRequest, SignatureUploadRequest) {
PrivateCrossSigningIdentity::new_with_account(self).await
}
@ -871,8 +844,8 @@ impl ReadOnlyAccount {
/// # Arguments
/// * `device` - The other account's device.
///
/// * `key_map` - A map from the algorithm and device id to the one-time key that the other
/// account created and shared with us.
/// * `key_map` - A map from the algorithm and device id to the one-time key
/// that the other account created and shared with us.
pub(crate) async fn create_outbound_session(
&self,
device: ReadOnlyDevice,
@ -909,24 +882,20 @@ impl ReadOnlyAccount {
)
})?;
let curve_key = device
.get_key(DeviceKeyAlgorithm::Curve25519)
.ok_or_else(|| {
SessionCreationError::DeviceMissingCurveKey(
device.user_id().to_owned(),
device.device_id().into(),
)
})?;
let curve_key = device.get_key(DeviceKeyAlgorithm::Curve25519).ok_or_else(|| {
SessionCreationError::DeviceMissingCurveKey(
device.user_id().to_owned(),
device.device_id().into(),
)
})?;
self.create_outbound_session_helper(curve_key, &one_time_key)
.await
.map_err(|e| {
SessionCreationError::OlmError(
device.user_id().to_owned(),
device.device_id().into(),
e,
)
})
self.create_outbound_session_helper(curve_key, &one_time_key).await.map_err(|e| {
SessionCreationError::OlmError(
device.user_id().to_owned(),
device.device_id().into(),
e,
)
})
}
/// Create a new session with another account given a pre-key Olm message.
@ -944,17 +913,10 @@ impl ReadOnlyAccount {
their_identity_key: &str,
message: PreKeyMessage,
) -> Result<Session, OlmSessionError> {
let session = self
.inner
.lock()
.await
.create_inbound_session_from(their_identity_key, message)?;
let session =
self.inner.lock().await.create_inbound_session_from(their_identity_key, message)?;
self.inner
.lock()
.await
.remove_one_time_keys(&session)
.expect(
self.inner.lock().await.remove_one_time_keys(&session).expect(
"Session was successfully created but the account doesn't hold a matching one-time key",
);
@ -1026,8 +988,7 @@ impl ReadOnlyAccount {
&self,
room_id: &RoomId,
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
self.create_group_session_pair(room_id, EncryptionSettings::default())
.await
self.create_group_session_pair(room_id, EncryptionSettings::default()).await
}
#[cfg(test)]
@ -1037,27 +998,19 @@ impl ReadOnlyAccount {
let device = ReadOnlyDevice::from_account(other).await;
let mut our_session = self
.create_outbound_session(device.clone(), &one_time)
.await
.unwrap();
let mut our_session =
self.create_outbound_session(device.clone(), &one_time).await.unwrap();
other.mark_keys_as_published().await;
let message = our_session
.encrypt(&device, EventType::Dummy, json!({}))
.await
.unwrap();
let message = our_session.encrypt(&device, EventType::Dummy, json!({})).await.unwrap();
let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme {
c
} else {
panic!("Invalid encrypted event algorithm");
};
let own_ciphertext = content
.ciphertext
.get(other.identity_keys.curve25519())
.unwrap();
let own_ciphertext = content.ciphertext.get(other.identity_keys.curve25519()).unwrap();
let message_type: u8 = own_ciphertext.message_type.try_into().unwrap();
let message =

View file

@ -147,10 +147,8 @@ impl InboundGroupSession {
forwarding_chains.push(sender_key.to_owned());
let mut sender_claimed_key = BTreeMap::new();
sender_claimed_key.insert(
DeviceKeyAlgorithm::Ed25519,
content.sender_claimed_ed25519_key.to_owned(),
);
sender_claimed_key
.insert(DeviceKeyAlgorithm::Ed25519, content.sender_claimed_ed25519_key.to_owned());
Ok(InboundGroupSession {
inner: Mutex::new(session).into(),
@ -217,11 +215,7 @@ impl InboundGroupSession {
let message_index = std::cmp::max(self.first_known_index(), message_index);
let session_key = ExportedGroupSessionKey(
self.inner
.lock()
.await
.export(message_index)
.expect("Can't export session"),
self.inner.lock().await.export(message_index).expect("Can't export session"),
);
ExportedRoomKey {
@ -314,9 +308,7 @@ impl InboundGroupSession {
let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?;
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
let decrypted_object = decrypted_value
.as_object_mut()
.ok_or(EventError::NotAnObject)?;
let decrypted_object = decrypted_value.as_object_mut().ok_or(EventError::NotAnObject)?;
// TODO better number conversion here.
let server_ts = event
@ -335,10 +327,8 @@ impl InboundGroupSession {
serde_json::to_value(&event.unsigned).unwrap_or_default(),
);
if let Some(decrypted_content) = decrypted_object
.get_mut("content")
.map(|c| c.as_object_mut())
.flatten()
if let Some(decrypted_content) =
decrypted_object.get_mut("content").map(|c| c.as_object_mut()).flatten()
{
if !decrypted_content.contains_key("m.relates_to") {
if let Some(relation) = &event.content.relates_to {
@ -350,19 +340,14 @@ impl InboundGroupSession {
}
}
Ok((
serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?,
message_index,
))
Ok((serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?, message_index))
}
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for InboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InboundGroupSession")
.field("session_id", &self.session_id())
.finish()
f.debug_struct("InboundGroupSession").field("session_id", &self.session_id()).finish()
}
}

View file

@ -108,10 +108,8 @@ impl From<ForwardedRoomKeyToDeviceEventContent> for ExportedRoomKey {
/// Convert the content of a forwarded room key into a exported room key.
fn from(forwarded_key: ForwardedRoomKeyToDeviceEventContent) -> Self {
let mut sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new();
sender_claimed_keys.insert(
DeviceKeyAlgorithm::Ed25519,
forwarded_key.sender_claimed_ed25519_key,
);
sender_claimed_keys
.insert(DeviceKeyAlgorithm::Ed25519, forwarded_key.sender_claimed_ed25519_key);
Self {
algorithm: forwarded_key.algorithm,
@ -143,10 +141,7 @@ mod test {
#[tokio::test]
#[cfg(target_os = "linux")]
async fn expiration() {
let settings = EncryptionSettings {
rotation_period_msgs: 1,
..Default::default()
};
let settings = EncryptionSettings { rotation_period_msgs: 1, ..Default::default() };
let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into());
let (session, _) = account
@ -156,9 +151,9 @@ mod test {
assert!(!session.expired());
let _ = session
.encrypt(AnyMessageEventContent::RoomMessage(
MessageEventContent::text_plain("Test message"),
))
.encrypt(AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain(
"Test message",
)))
.await;
assert!(session.expired());

View file

@ -96,12 +96,10 @@ impl EncryptionSettings {
/// Create new encryption settings using an `EncryptionEventContent` and a
/// history visibility.
pub fn new(content: EncryptionEventContent, history_visibility: HistoryVisibility) -> Self {
let rotation_period: Duration = content
.rotation_period_ms
.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
let rotation_period_msgs: u64 = content
.rotation_period_msgs
.map_or(ROTATION_MESSAGES, Into::into);
let rotation_period: Duration =
content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
let rotation_period_msgs: u64 =
content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
Self {
algorithm: content.algorithm,
@ -180,8 +178,7 @@ impl OutboundGroupSession {
request: Arc<ToDeviceRequest>,
message_index: u32,
) {
self.to_share_with_set
.insert(request_id, (request, message_index));
self.to_share_with_set.insert(request_id, (request, message_index));
}
/// This should be called if an the user wishes to rotate this session.
@ -219,10 +216,7 @@ impl OutboundGroupSession {
});
user_pairs.for_each(|(u, d)| {
self.shared_with_set
.entry(u)
.or_insert_with(DashMap::new)
.extend(d);
self.shared_with_set.entry(u).or_insert_with(DashMap::new).extend(d);
});
if self.to_share_with_set.is_empty() {
@ -235,11 +229,8 @@ impl OutboundGroupSession {
self.mark_as_shared();
}
} else {
let request_ids: Vec<String> = self
.to_share_with_set
.iter()
.map(|e| e.key().to_string())
.collect();
let request_ids: Vec<String> =
self.to_share_with_set.iter().map(|e| e.key().to_string()).collect();
error!(
all_request_ids = ?request_ids,
@ -290,11 +281,7 @@ impl OutboundGroupSession {
let relates_to: Option<Relation> = json_content
.get("content")
.map(|c| {
c.get("m.relates_to")
.cloned()
.map(|r| serde_json::from_value(r).ok())
})
.map(|c| c.get("m.relates_to").cloned().map(|r| serde_json::from_value(r).ok()))
.flatten()
.flatten();
@ -437,10 +424,7 @@ impl OutboundGroupSession {
/// Get the list of requests that need to be sent out for this session to be
/// marked as shared.
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
self.to_share_with_set
.iter()
.map(|i| i.value().0.clone())
.collect()
self.to_share_with_set.iter().map(|i| i.value().0.clone()).collect()
}
/// Get the list of request ids this session is waiting for to be sent out.
@ -455,11 +439,11 @@ impl OutboundGroupSession {
///
/// # Arguments
///
/// * `device_id` - The device id of the device that created this session. Put differently, our
/// own device id.
/// * `device_id` - The device id of the device that created this session.
/// Put differently, our own device id.
///
/// * `identity_keys` - The identity keys of the device that created this session, our own
/// identity keys.
/// * `identity_keys` - The identity keys of the device that created this
/// session, our own identity keys.
///
/// * `pickle` - The pickled version of the `OutboundGroupSession`.
///
@ -501,7 +485,8 @@ impl OutboundGroupSession {
///
/// # Arguments
///
/// * `pickle_mode` - The mode that should be used to pickle the group session,
/// * `pickle_mode` - The mode that should be used to pickle the group
/// session,
/// either an unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickling_mode: PicklingMode) -> PickledOutboundGroupSession {
let pickle: OutboundGroupSessionPickle =
@ -522,10 +507,7 @@ impl OutboundGroupSession {
(
u.key().clone(),
#[allow(clippy::map_clone)]
u.value()
.iter()
.map(|d| (d.key().clone(), *d.value()))
.collect(),
u.value().iter().map(|d| (d.key().clone(), *d.value())).collect(),
)
})
.collect(),
@ -572,10 +554,7 @@ pub struct PickledOutboundGroupSession {
/// The room id this session is used for.
pub room_id: Arc<RoomId>,
/// The timestamp when this session was created.
#[serde(
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
pub creation_time: Instant,
/// The number of messages this session has already encrypted.
pub message_count: u64,

View file

@ -91,21 +91,12 @@ pub(crate) mod test {
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob
.one_time_keys()
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key =
bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice
.create_outbound_session_helper(&sender_key, &one_time_key)
.await
.unwrap();
let session =
alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
(alice, session)
}
@ -121,10 +112,7 @@ pub(crate) mod test {
assert_ne!(identity_keys.keys().len(), 0);
assert_ne!(identity_keys.iter().len(), 0);
assert!(identity_keys.contains_key("ed25519"));
assert_eq!(
identity_keys.ed25519(),
identity_keys.get("ed25519").unwrap()
);
assert_eq!(identity_keys.ed25519(), identity_keys.get("ed25519").unwrap());
assert!(!identity_keys.curve25519().is_empty());
account.mark_as_shared();
@ -148,10 +136,7 @@ pub(crate) mod test {
assert_ne!(one_time_keys.iter().len(), 0);
assert!(one_time_keys.contains_key("curve25519"));
assert_eq!(one_time_keys.curve25519().keys().len(), 10);
assert_eq!(
one_time_keys.curve25519(),
one_time_keys.get("curve25519").unwrap()
);
assert_eq!(one_time_keys.curve25519(), one_time_keys.get("curve25519").unwrap());
account.mark_keys_as_published().await;
let one_time_keys = account.one_time_keys().await;
@ -167,13 +152,7 @@ pub(crate) mod test {
let one_time_keys = alice.one_time_keys().await;
alice.mark_keys_as_published().await;
let one_time_key = one_time_keys
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = one_time_keys.curve25519().iter().next().unwrap().1.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
@ -197,10 +176,7 @@ pub(crate) mod test {
.await
.unwrap();
assert!(alice_session
.matches(bob_keys.curve25519(), prekey_message)
.await
.unwrap());
assert!(alice_session.matches(bob_keys.curve25519(), prekey_message).await.unwrap());
assert_eq!(bob_session.session_id(), alice_session.session_id());
@ -213,10 +189,7 @@ pub(crate) mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost");
let (outbound, _) = alice
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let (outbound, _) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
@ -239,10 +212,7 @@ pub(crate) mod test {
let plaintext = "This is a secret to everybody".to_owned();
let ciphertext = outbound.encrypt_helper(plaintext.clone()).await;
assert_eq!(
plaintext,
inbound.decrypt_helper(ciphertext).await.unwrap().0
);
assert_eq!(plaintext, inbound.decrypt_helper(ciphertext).await.unwrap().0);
}
#[tokio::test]
@ -250,10 +220,7 @@ pub(crate) mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost");
let (_, inbound) = alice
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let (_, inbound) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
let export = inbound.export().await;
let export: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();

View file

@ -118,10 +118,8 @@ impl Session {
.get_key(DeviceKeyAlgorithm::Ed25519)
.ok_or(EventError::MissingSigningKey)?;
let relates_to = content
.get("m.relates_to")
.cloned()
.and_then(|v| serde_json::from_value(v).ok());
let relates_to =
content.get("m.relates_to").cloned().and_then(|v| serde_json::from_value(v).ok());
let payload = json!({
"sender": self.user_id.as_str(),
@ -171,10 +169,7 @@ impl Session {
their_identity_key: &str,
message: PreKeyMessage,
) -> Result<bool, OlmSessionError> {
self.inner
.lock()
.await
.matches_inbound_session_from(their_identity_key, message)
self.inner.lock().await.matches_inbound_session_from(their_identity_key, message)
}
/// Returns the unique identifier for this session.
@ -256,16 +251,10 @@ pub struct PickledSession {
/// The curve25519 key of the other user that we share this session with.
pub sender_key: String,
/// The relative time elapsed since the session was created.
#[serde(
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
pub creation_time: Instant,
/// The relative time elapsed since the session was last used.
#[serde(
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
pub last_use_time: Instant,
}

View file

@ -185,10 +185,7 @@ impl PrivateCrossSigningIdentity {
signed_keys
.entry((&*self.user_id).to_owned())
.or_insert_with(BTreeMap::new)
.insert(
device_keys.device_id.to_string(),
serde_json::to_value(device_keys)?,
);
.insert(device_keys.device_id.to_string(), serde_json::to_value(device_keys)?);
Ok(SignatureUploadRequest::new(signed_keys))
}
@ -228,10 +225,7 @@ impl PrivateCrossSigningIdentity {
signature,
);
let master = MasterSigning {
inner: master,
public_key: public_key.into(),
};
let master = MasterSigning { inner: master, public_key: public_key.into() };
let identity = Self::new_helper(account.user_id(), master).await;
let signature_request = identity
@ -249,20 +243,14 @@ impl PrivateCrossSigningIdentity {
let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning);
master.sign_subkey(&mut public_key).await;
let user = UserSigning {
inner: user,
public_key: public_key.into(),
};
let user = UserSigning { inner: user, public_key: public_key.into() };
let self_signing = Signing::new();
let mut public_key =
self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning);
master.sign_subkey(&mut public_key).await;
let self_signing = SelfSigning {
inner: self_signing,
public_key: public_key.into(),
};
let self_signing = SelfSigning { inner: self_signing, public_key: public_key.into() };
Self {
user_id: Arc::new(user_id.to_owned()),
@ -280,10 +268,7 @@ impl PrivateCrossSigningIdentity {
let master = Signing::new();
let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master);
let master = MasterSigning {
inner: master,
public_key: public_key.into(),
};
let master = MasterSigning { inner: master, public_key: public_key.into() };
Self::new_helper(&user_id, master).await
}
@ -333,11 +318,7 @@ impl PrivateCrossSigningIdentity {
None
};
let pickle = PickledSignings {
master_key,
user_signing_key,
self_signing_key,
};
let pickle = PickledSignings { master_key, user_signing_key, self_signing_key };
let pickle = serde_json::to_string(&pickle)?;
@ -389,35 +370,16 @@ impl PrivateCrossSigningIdentity {
/// Get the upload request that is needed to share the public keys of this
/// identity.
pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest {
let master_key = self
.master_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let master_key =
self.master_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
let user_signing_key = self
.user_signing_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let user_signing_key =
self.user_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
let self_signing_key = self
.self_signing_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let self_signing_key =
self.self_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
UploadSigningKeysRequest {
master_key,
self_signing_key,
user_signing_key,
}
UploadSigningKeysRequest { master_key, self_signing_key, user_signing_key }
}
}
@ -480,28 +442,12 @@ mod test {
assert!(master_key
.public_key
.verify_subkey(
&identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.verify_subkey(&identity.self_signing_key.lock().await.as_ref().unwrap().public_key,)
.is_ok());
assert!(master_key
.public_key
.verify_subkey(
&identity
.user_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.verify_subkey(&identity.user_signing_key.lock().await.as_ref().unwrap().public_key,)
.is_ok());
}
@ -511,15 +457,11 @@ mod test {
let pickled = identity.pickle(pickle_key()).await.unwrap();
let unpickled = PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key())
.await
.unwrap();
let unpickled =
PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key()).await.unwrap();
assert_eq!(identity.user_id, unpickled.user_id);
assert_eq!(
&*identity.master_key.lock().await,
&*unpickled.master_key.lock().await
);
assert_eq!(&*identity.master_key.lock().await, &*unpickled.master_key.lock().await);
assert_eq!(
&*identity.user_signing_key.lock().await,
&*unpickled.user_signing_key.lock().await
@ -590,9 +532,6 @@ mod test {
bob_public.master_key = master.into();
user_signing
.public_key
.verify_master_key(bob_public.master_key())
.unwrap();
user_signing.public_key.verify_master_key(bob_public.master_key()).unwrap();
}
}

View file

@ -68,9 +68,7 @@ pub struct Signing {
impl std::fmt::Debug for Signing {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Signing")
.field("public_key", &self.public_key.as_str())
.finish()
f.debug_struct("Signing").field("public_key", &self.public_key.as_str()).finish()
}
}
@ -151,10 +149,7 @@ impl MasterSigning {
) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
Ok(Self { inner, public_key: pickle.public_key.into() })
}
pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
@ -195,10 +190,7 @@ impl UserSigning {
user: &UserIdentity,
) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> {
let user_master: &CrossSigningKey = user.master_key().as_ref();
let signature = self
.inner
.sign_json(serde_json::to_value(user_master)?)
.await?;
let signature = self.inner.sign_json(serde_json::to_value(user_master)?).await?;
let mut signatures = BTreeMap::new();
@ -223,10 +215,7 @@ impl UserSigning {
) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
Ok(Self { inner, public_key: pickle.public_key.into() })
}
}
@ -274,10 +263,7 @@ impl SelfSigning {
) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
Ok(Self { inner, public_key: pickle.public_key.into() })
}
}
@ -348,17 +334,12 @@ impl Signing {
getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object");
let nonce = GenericArray::from_slice(nonce.as_slice());
let ciphertext = cipher
.encrypt(nonce, self.seed.as_slice())
.expect("Can't encrypt signing pickle");
let ciphertext =
cipher.encrypt(nonce, self.seed.as_slice()).expect("Can't encrypt signing pickle");
let ciphertext = encode(ciphertext);
let pickle = InnerPickle {
version: 1,
nonce: encode(nonce.as_slice()),
ciphertext,
};
let pickle = InnerPickle { version: 1, nonce: encode(nonce.as_slice()), ciphertext };
PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing"))
}
@ -371,11 +352,8 @@ impl Signing {
let mut keys = BTreeMap::new();
keys.insert(
DeviceKeyId::from_parts(
DeviceKeyAlgorithm::Ed25519,
self.public_key().as_str().into(),
)
.to_string(),
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.public_key().as_str().into())
.to_string(),
self.public_key().to_string(),
);

View file

@ -29,9 +29,7 @@ pub(crate) struct Utility {
impl Utility {
pub fn new() -> Self {
Self {
inner: OlmUtility::new(),
}
Self { inner: OlmUtility::new() }
}
/// Verify a signed JSON object.
@ -66,29 +64,20 @@ impl Utility {
let unsigned = json_object.remove("unsigned");
let signatures = json_object.remove("signatures");
let canonical_json: CanonicalJsonValue = json
.clone()
.try_into()
.map_err(|_| SignatureError::NotAnObject)?;
let canonical_json: CanonicalJsonValue =
json.clone().try_into().map_err(|_| SignatureError::NotAnObject)?;
let canonical_json: String = canonical_json.to_string();
let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?;
let signature_object = signatures
.as_object()
.ok_or(SignatureError::NoSignatureFound)?;
let signature = signature_object
.get(user_id.as_str())
.ok_or(SignatureError::NoSignatureFound)?;
let signature = signature
.get(key_id.to_string())
.ok_or(SignatureError::NoSignatureFound)?;
let signature_object = signatures.as_object().ok_or(SignatureError::NoSignatureFound)?;
let signature =
signature_object.get(user_id.as_str()).ok_or(SignatureError::NoSignatureFound)?;
let signature =
signature.get(key_id.to_string()).ok_or(SignatureError::NoSignatureFound)?;
let signature = signature.as_str().ok_or(SignatureError::NoSignatureFound)?;
let ret = match self
.inner
.ed25519_verify(signing_key, &canonical_json, signature)
{
let ret = match self.inner.ed25519_verify(signing_key, &canonical_json, signature) {
Ok(_) => Ok(()),
Err(_) => Err(SignatureError::VerificationError),
};

View file

@ -108,11 +108,7 @@ pub struct KeysQueryRequest {
impl KeysQueryRequest {
pub(crate) fn new(device_keys: BTreeMap<UserId, Vec<DeviceIdBox>>) -> Self {
Self {
timeout: None,
device_keys,
token: None,
}
Self { timeout: None, device_keys, token: None }
}
}
@ -176,19 +172,13 @@ impl From<SignatureUploadRequest> for OutgoingRequests {
impl From<OutgoingVerificationRequest> for OutgoingRequest {
fn from(r: OutgoingVerificationRequest) -> Self {
Self {
request_id: r.request_id(),
request: Arc::new(r.into()),
}
Self { request_id: r.request_id(), request: Arc::new(r.into()) }
}
}
impl From<SignatureUploadRequest> for OutgoingRequest {
fn from(r: SignatureUploadRequest) -> Self {
Self {
request_id: Uuid::new_v4(),
request: Arc::new(r.into()),
}
Self { request_id: Uuid::new_v4(), request: Arc::new(r.into()) }
}
}

View file

@ -104,10 +104,7 @@ impl GroupSessionCache {
room_id: &RoomId,
session_id: &str,
) -> StoreResult<Option<OutboundGroupSession>> {
Ok(self
.get_or_load(room_id)
.await?
.filter(|o| session_id == o.session_id()))
Ok(self.get_or_load(room_id).await?.filter(|o| session_id == o.session_id()))
}
}
@ -126,11 +123,7 @@ impl GroupSessionManager {
const MAX_TO_DEVICE_MESSAGES: usize = 250;
pub(crate) fn new(account: Account, store: Store) -> Self {
Self {
account,
store: store.clone(),
sessions: GroupSessionCache::new(store),
}
Self { account, store: store.clone(), sessions: GroupSessionCache::new(store) }
}
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
@ -230,9 +223,7 @@ impl GroupSessionManager {
Ok((s, None))
}
} else {
self.create_outbound_group_session(room_id, settings)
.await
.map(|(o, i)| (o, i.into()))
self.create_outbound_group_session(room_id, settings).await.map(|(o, i)| (o, i.into()))
}
}
@ -252,13 +243,10 @@ impl GroupSessionManager {
let used_session = match encrypted {
Ok((session, encrypted)) => {
message
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
message.entry(device.user_id().clone()).or_insert_with(BTreeMap::new).insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
Some(session)
}
// TODO we'll want to create m.room_key.withheld here.
@ -270,10 +258,8 @@ impl GroupSessionManager {
Ok((used_session, message))
};
let tasks: Vec<_> = devices
.iter()
.map(|d| spawn(encrypt(d.clone(), content.clone())))
.collect();
let tasks: Vec<_> =
devices.iter().map(|d| spawn(encrypt(d.clone(), content.clone()))).collect();
let results = join_all(tasks).await;
@ -285,20 +271,14 @@ impl GroupSessionManager {
}
for (user, device_messages) in message.into_iter() {
messages
.entry(user)
.or_insert_with(BTreeMap::new)
.extend(device_messages);
messages.entry(user).or_insert_with(BTreeMap::new).extend(device_messages);
}
}
let id = Uuid::new_v4();
let request = ToDeviceRequest {
event_type: EventType::RoomEncrypted,
txn_id: id,
messages,
};
let request =
ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages };
trace!(
recipient_count = request.message_count(),
@ -330,20 +310,14 @@ impl GroupSessionManager {
"Calculating group session recipients"
);
let users_shared_with: HashSet<UserId> = outbound
.shared_with_set
.iter()
.map(|k| k.key().clone())
.collect();
let users_shared_with: HashSet<UserId> =
outbound.shared_with_set.iter().map(|k| k.key().clone()).collect();
let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect();
// A user left if a user is missing from the set of users that should
// get the session but is in the set of users that received the session.
let user_left = !users_shared_with
.difference(&users)
.collect::<HashSet<_>>()
.is_empty();
let user_left = !users_shared_with.difference(&users).collect::<HashSet<_>>().is_empty();
let visibility_changed = outbound.settings().history_visibility != history_visibility;
@ -358,10 +332,8 @@ impl GroupSessionManager {
for user_id in users {
let user_devices = self.store.get_user_devices(&user_id).await?;
let non_blacklisted_devices: Vec<Device> = user_devices
.devices()
.filter(|d| !d.is_blacklisted())
.collect();
let non_blacklisted_devices: Vec<Device> =
user_devices.devices().filter(|d| !d.is_blacklisted()).collect();
// If we haven't already concluded that the session should be
// rotated for other reasons, we also need to check whether any
@ -369,10 +341,8 @@ impl GroupSessionManager {
// meantime. If so, we should also rotate the session.
if !should_rotate {
// Device IDs that should receive this session
let non_blacklisted_device_ids: HashSet<&DeviceId> = non_blacklisted_devices
.iter()
.map(|d| d.device_id())
.collect();
let non_blacklisted_device_ids: HashSet<&DeviceId> =
non_blacklisted_devices.iter().map(|d| d.device_id()).collect();
if let Some(shared) = outbound.shared_with_set.get(user_id) {
#[allow(clippy::map_clone)]
@ -388,9 +358,8 @@ impl GroupSessionManager {
//
// represents newly deleted or blacklisted devices. If this
// set is non-empty, we must rotate.
let newly_deleted_or_blacklisted = shared
.difference(&non_blacklisted_device_ids)
.collect::<HashSet<_>>();
let newly_deleted_or_blacklisted =
shared.difference(&non_blacklisted_device_ids).collect::<HashSet<_>>();
if !newly_deleted_or_blacklisted.is_empty() {
should_rotate = true;
@ -398,10 +367,7 @@ impl GroupSessionManager {
};
}
devices
.entry(user_id.clone())
.or_insert_with(Vec::new)
.extend(non_blacklisted_devices);
devices.entry(user_id.clone()).or_insert_with(Vec::new).extend(non_blacklisted_devices);
}
debug!(
@ -461,25 +427,22 @@ impl GroupSessionManager {
let history_visibility = encryption_settings.history_visibility.clone();
let mut changes = Changes::default();
let (outbound, inbound) = self
.get_or_create_outbound_session(room_id, encryption_settings.clone())
.await?;
let (outbound, inbound) =
self.get_or_create_outbound_session(room_id, encryption_settings.clone()).await?;
if let Some(inbound) = inbound {
changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound);
}
let (should_rotate, devices) = self
.collect_session_recipients(users, history_visibility, &outbound)
.await?;
let (should_rotate, devices) =
self.collect_session_recipients(users, history_visibility, &outbound).await?;
let outbound = if should_rotate {
let old_session_id = outbound.session_id();
let (outbound, inbound) = self
.create_outbound_group_session(room_id, encryption_settings)
.await?;
let (outbound, inbound) =
self.create_outbound_group_session(room_id, encryption_settings).await?;
changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound);
@ -514,9 +477,7 @@ impl GroupSessionManager {
if !devices.is_empty() {
let users = devices.iter().fold(BTreeMap::new(), |mut acc, d| {
acc.entry(d.user_id())
.or_insert_with(BTreeSet::new)
.insert(d.device_id());
acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
acc
});
@ -625,14 +586,8 @@ mod test {
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
machine
.mark_request_as_sent(&uuid, &keys_query)
.await
.unwrap();
machine
.mark_request_as_sent(&uuid, &keys_claim)
.await
.unwrap();
machine.mark_request_as_sent(&uuid, &keys_query).await.unwrap();
machine.mark_request_as_sent(&uuid, &keys_claim).await.unwrap();
machine
}
@ -646,11 +601,7 @@ mod test {
let users: Vec<_> = keys_claim.one_time_keys.keys().collect();
let requests = machine
.share_group_session(
&room_id,
users.clone().into_iter(),
EncryptionSettings::default(),
)
.share_group_session(&room_id, users.clone().into_iter(), EncryptionSettings::default())
.await
.unwrap();

View file

@ -77,11 +77,7 @@ impl SessionManager {
}
pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> {
if let Some(device) = self
.store
.get_device_from_curve_key(sender, curve_key)
.await?
{
if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await? {
let sessions = device.get_sessions().await?;
if let Some(sessions) = sessions {
@ -120,25 +116,16 @@ impl SessionManager {
///
/// If the device was wedged this will queue up a dummy to-device message.
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
if self
.wedged_devices
.get(user_id)
.map(|d| d.remove(device_id))
.flatten()
.is_some()
{
if self.wedged_devices.get(user_id).map(|d| d.remove(device_id)).flatten().is_some() {
if let Some(device) = self.store.get_device(user_id, device_id).await? {
let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
messages
.entry(device.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
let request = OutgoingRequest {
request_id: id,
@ -347,9 +334,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &device_id);
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
store.save_account(account.clone()).await.unwrap();
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(
user_id.clone(),
)));
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id.clone())));
let verification =
VerificationMachine::new(account.clone(), identity.clone(), store.clone());
@ -358,10 +343,7 @@ mod test {
let store = Store::new(user_id.clone(), identity, store, verification);
let account = Account {
inner: account,
store: store.clone(),
};
let account = Account { inner: account, store: store.clone() };
let session_cache = GroupSessionCache::new(store.clone());
@ -405,10 +387,7 @@ mod test {
let response = KeyClaimResponse::new(one_time_keys);
manager
.receive_keys_claim_response(&response)
.await
.unwrap();
manager.receive_keys_claim_response(&response).await.unwrap();
assert!(manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter())
@ -434,11 +413,7 @@ mod test {
let bob_device = ReadOnlyDevice::from_account(&bob).await;
session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601));
manager
.store
.save_devices(&[bob_device.clone()])
.await
.unwrap();
manager.store.save_devices(&[bob_device.clone()]).await.unwrap();
manager.store.save_sessions(&[session]).await.unwrap();
assert!(manager
@ -451,10 +426,7 @@ mod test {
assert!(!manager.users_for_key_claim.contains_key(bob.user_id()));
assert!(!manager.is_device_wedged(&bob_device));
manager
.mark_device_as_wedged(bob_device.user_id(), &curve_key)
.await
.unwrap();
manager.mark_device_as_wedged(bob_device.user_id(), &curve_key).await.unwrap();
assert!(manager.is_device_wedged(&bob_device));
assert!(manager.users_for_key_claim.contains_key(bob.user_id()));
@ -480,10 +452,7 @@ mod test {
assert!(manager.outgoing_to_device_requests.is_empty());
manager
.receive_keys_claim_response(&response)
.await
.unwrap();
manager.receive_keys_claim_response(&response).await.unwrap();
assert!(!manager.is_device_wedged(&bob_device));
assert!(manager

View file

@ -39,9 +39,7 @@ pub struct SessionStore {
impl SessionStore {
/// Create a new empty Session store.
pub fn new() -> Self {
SessionStore {
entries: Arc::new(DashMap::new()),
}
SessionStore { entries: Arc::new(DashMap::new()) }
}
/// Add a session to the store.
@ -72,8 +70,7 @@ impl SessionStore {
/// Add a list of sessions belonging to the sender key.
pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
self.entries
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
self.entries.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
}
}
@ -87,9 +84,7 @@ pub struct GroupSessionStore {
impl GroupSessionStore {
/// Create a new empty store.
pub fn new() -> Self {
GroupSessionStore {
entries: Arc::new(DashMap::new()),
}
GroupSessionStore { entries: Arc::new(DashMap::new()) }
}
/// Add an inbound group session to the store.
@ -148,9 +143,7 @@ pub struct DeviceStore {
impl DeviceStore {
/// Create a new empty device store.
pub fn new() -> Self {
DeviceStore {
entries: Arc::new(DashMap::new()),
}
DeviceStore { entries: Arc::new(DashMap::new()) }
}
/// Add a device to the store.
@ -167,19 +160,15 @@ impl DeviceStore {
/// Get the device with the given device_id and belonging to the given user.
pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries
.get(user_id)
.and_then(|m| m.get(device_id).map(|d| d.value().clone()))
self.entries.get(user_id).and_then(|m| m.get(device_id).map(|d| d.value().clone()))
}
/// Remove the device with the given device_id and belonging to the given user.
/// Remove the device with the given device_id and belonging to the given
/// user.
///
/// Returns the device if it was removed, None if it wasn't in the store.
pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries
.get(user_id)
.and_then(|m| m.remove(device_id))
.map(|(_, d)| d)
self.entries.get(user_id).and_then(|m| m.remove(device_id)).map(|(_, d)| d)
}
/// Get a read-only view over all devices of the given user.
@ -240,10 +229,8 @@ mod test {
let (account, _) = get_account_and_session().await;
let room_id = room_id!("!test:localhost");
let (outbound, _) = account
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let (outbound, _) =
account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
@ -262,9 +249,7 @@ mod test {
let store = GroupSessionStore::new();
store.add(inbound.clone());
let loaded_session = store
.get(&room_id, "test_key", outbound.session_id())
.unwrap();
let loaded_session = store.get(&room_id, "test_key", outbound.session_id()).unwrap();
assert_eq!(inbound, loaded_session);
}

View file

@ -37,10 +37,7 @@ use crate::{
};
fn encode_key_info(info: &RequestedKeyInfo) -> String {
format!(
"{}{}{}{}",
info.room_id, info.sender_key, info.algorithm, info.session_id
)
format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id)
}
/// An in-memory only store that will forget all the E2EE key once it's dropped.
@ -121,22 +118,14 @@ impl CryptoStore for MemoryStore {
async fn save_changes(&self, mut changes: Changes) -> Result<()> {
self.save_sessions(changes.sessions).await;
self.save_inbound_group_sessions(changes.inbound_group_sessions)
.await;
self.save_inbound_group_sessions(changes.inbound_group_sessions).await;
self.save_devices(changes.devices.new).await;
self.save_devices(changes.devices.changed).await;
self.delete_devices(changes.devices.deleted).await;
for identity in changes
.identities
.new
.drain(..)
.chain(changes.identities.changed)
{
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
for identity in changes.identities.new.drain(..).chain(changes.identities.changed) {
let _ = self.identities.insert(identity.user_id().to_owned(), identity.clone());
}
for hash in changes.message_hashes {
@ -167,9 +156,7 @@ impl CryptoStore for MemoryStore {
sender_key: &str,
session_id: &str,
) -> Result<Option<InboundGroupSession>> {
Ok(self
.inbound_group_sessions
.get(room_id, sender_key, session_id))
Ok(self.inbound_group_sessions.get(room_id, sender_key, session_id))
}
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
@ -250,10 +237,7 @@ impl CryptoStore for MemoryStore {
&self,
request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>> {
Ok(self
.outgoing_key_requests
.get(&request_id)
.map(|r| r.clone()))
Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone()))
}
async fn get_key_request_by_info(
@ -278,12 +262,10 @@ impl CryptoStore for MemoryStore {
}
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
self.outgoing_key_requests
.remove(&request_id)
.and_then(|(_, i)| {
let key_info_string = encode_key_info(&i.info);
self.key_requests_by_info.remove(&key_info_string)
});
self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| {
let key_info_string = encode_key_info(&i.info);
self.key_requests_by_info.remove(&key_info_string)
});
Ok(())
}
@ -309,11 +291,7 @@ mod test {
store.save_sessions(vec![session.clone()]).await;
let sessions = store
.get_sessions(&session.sender_key)
.await
.unwrap()
.unwrap();
let sessions = store.get_sessions(&session.sender_key).await.unwrap().unwrap();
let sessions = sessions.lock().await;
let loaded_session = &sessions[0];
@ -326,10 +304,8 @@ mod test {
let (account, _) = get_account_and_session().await;
let room_id = room_id!("!test:localhost");
let (outbound, _) = account
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let (outbound, _) =
account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
let inbound = InboundGroupSession::new(
"test_key",
"test_key",
@ -340,9 +316,7 @@ mod test {
.unwrap();
let store = MemoryStore::new();
let _ = store
.save_inbound_group_sessions(vec![inbound.clone()])
.await;
let _ = store.save_inbound_group_sessions(vec![inbound.clone()]).await;
let loaded_session = store
.get_inbound_group_session(&room_id, "test_key", outbound.session_id())
@ -359,11 +333,8 @@ mod test {
store.save_devices(vec![device.clone()]).await;
let loaded_device = store
.get_device(device.user_id(), device.device_id())
.await
.unwrap()
.unwrap();
let loaded_device =
store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
assert_eq!(device, loaded_device);
@ -377,11 +348,7 @@ mod test {
assert_eq!(&device, loaded_device);
store.delete_devices(vec![device.clone()]).await;
assert!(store
.get_device(device.user_id(), device.device_id())
.await
.unwrap()
.is_none());
assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
}
#[tokio::test]
@ -389,14 +356,8 @@ mod test {
let device = get_device();
let store = MemoryStore::new();
assert!(store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
assert!(store.is_user_tracked(device.user_id()));
}
@ -405,10 +366,8 @@ mod test {
async fn test_message_hash() {
let store = MemoryStore::new();
let hash = OlmMessageHash {
sender_key: "test_sender".to_owned(),
hash: "test_hash".to_owned(),
};
let hash =
OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
let mut changes = Changes::default();
changes.message_hashes.push(hash.clone());

View file

@ -143,12 +143,7 @@ impl Store {
store: Arc<Box<dyn CryptoStore>>,
verification_machine: VerificationMachine,
) -> Self {
Self {
user_id,
identity,
inner: store,
verification_machine,
}
Self { user_id, identity, inner: store, verification_machine }
}
pub async fn get_readonly_device(
@ -160,10 +155,7 @@ impl Store {
}
pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let changes = Changes {
sessions: sessions.to_vec(),
..Default::default()
};
let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
self.save_changes(changes).await
}
@ -171,10 +163,7 @@ impl Store {
#[cfg(test)]
pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let changes = Changes {
devices: DeviceChanges {
changed: devices.to_vec(),
..Default::default()
},
devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
..Default::default()
};
@ -186,10 +175,7 @@ impl Store {
&self,
sessions: &[InboundGroupSession],
) -> Result<()> {
let changes = Changes {
inbound_group_sessions: sessions.to_vec(),
..Default::default()
};
let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
self.save_changes(changes).await
}
@ -208,8 +194,7 @@ impl Store {
) -> Result<Option<Device>> {
self.get_user_devices(user_id).await.map(|d| {
d.devices().find(|d| {
d.get_key(DeviceKeyAlgorithm::Curve25519)
.map_or(false, |k| k == curve_key)
d.get_key(DeviceKeyAlgorithm::Curve25519).map_or(false, |k| k == curve_key)
})
})
}
@ -217,12 +202,8 @@ impl Store {
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
let devices = self.inner.get_user_devices(user_id).await?;
let own_identity = self
.inner
.get_user_identity(&self.user_id)
.await?
.map(|i| i.own().cloned())
.flatten();
let own_identity =
self.inner.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
let device_owner_identity = self.inner.get_user_identity(user_id).await.ok().flatten();
Ok(UserDevices {
@ -239,24 +220,17 @@ impl Store {
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Option<Device>> {
let own_identity = self
.get_user_identity(&self.user_id)
.await?
.map(|i| i.own().cloned())
.flatten();
let own_identity =
self.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
let device_owner_identity = self.get_user_identity(user_id).await?;
Ok(self
.inner
.get_device(user_id, device_id)
.await?
.map(|d| Device {
inner: d,
private_identity: self.identity.clone(),
verification_machine: self.verification_machine.clone(),
own_identity,
device_owner_identity,
}))
Ok(self.inner.get_device(user_id, device_id).await?.map(|d| Device {
inner: d,
private_identity: self.identity.clone(),
verification_machine: self.verification_machine.clone(),
own_identity,
device_owner_identity,
}))
}
}
@ -364,7 +338,8 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Get all the inbound group sessions we have stored.
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>;
/// Get the outobund group sessions we have stored that is used for the given room.
/// Get the outobund group sessions we have stored that is used for the
/// given room.
async fn get_outbound_group_sessions(
&self,
room_id: &RoomId,

View file

@ -113,9 +113,7 @@ impl PickleKey {
/// Get a `PicklingMode` version of this pickle key.
pub fn pickle_mode(&self) -> PicklingMode {
PicklingMode::Encrypted {
key: self.aes256_key.clone(),
}
PicklingMode::Encrypted { key: self.aes256_key.clone() }
}
/// Get the raw AES256 key.
@ -141,10 +139,7 @@ impl PickleKey {
getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key");
let ciphertext = cipher
.encrypt(
&GenericArray::from_slice(nonce.as_ref()),
self.aes256_key.as_slice(),
)
.encrypt(&GenericArray::from_slice(nonce.as_ref()), self.aes256_key.as_slice())
.expect("Can't encrypt pickle key");
EncryptedPickleKey {
@ -180,9 +175,7 @@ impl PickleKey {
}
};
Ok(Self {
aes256_key: decrypted,
})
Ok(Self { aes256_key: decrypted })
}
}

View file

@ -96,13 +96,7 @@ impl EncodeKey for &str {
impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> {
[
self.0.as_bytes(),
&[Self::SEPARATOR],
self.1.as_bytes(),
&[Self::SEPARATOR],
]
.concat()
[self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
}
}
@ -163,9 +157,7 @@ impl std::fmt::Debug for SledStore {
if let Some(path) = &self.path {
f.debug_struct("SledStore").field("path", &path).finish()
} else {
f.debug_struct("SledStore")
.field("path", &"memory store")
.finish()
f.debug_struct("SledStore").field("path", &"memory store").finish()
}
}
}
@ -252,9 +244,8 @@ impl SledStore {
}
fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> {
let key = if let Some(key) = database
.get("pickle_key".encode())?
.map(|v| serde_json::from_slice(&v))
let key = if let Some(key) =
database.get("pickle_key".encode())?.map(|v| serde_json::from_slice(&v))
{
PickleKey::from_encrypted(passphrase, key?)
.map_err(|_| CryptoStoreError::UnpicklingError)?
@ -296,9 +287,7 @@ impl SledStore {
&self,
room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>> {
let account_info = self
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
self.outbound_group_sessions
.get(room_id.encode())?
@ -500,17 +489,11 @@ impl SledStore {
&self,
id: &[u8],
) -> Result<Option<OutgoingKeyRequest>> {
let request = self
.outgoing_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?;
let request =
self.outgoing_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?;
let request = if request.is_none() {
self.unsent_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?
self.unsent_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?
} else {
request
};
@ -552,10 +535,7 @@ impl CryptoStore for SledStore {
*self.account_info.write().unwrap() = Some(account_info);
let changes = Changes {
account: Some(account),
..Default::default()
};
let changes = Changes { account: Some(account), ..Default::default() };
self.save_changes(changes).await
}
@ -578,9 +558,7 @@ impl CryptoStore for SledStore {
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
let account_info = self
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
if self.session_cache.get(sender_key).is_none() {
let sessions: Result<Vec<Session>> = self
@ -612,16 +590,10 @@ impl CryptoStore for SledStore {
session_id: &str,
) -> Result<Option<InboundGroupSession>> {
let key = (room_id.as_str(), sender_key, session_id).encode();
let pickle = self
.inbound_group_sessions
.get(&key)?
.map(|p| serde_json::from_slice(&p));
let pickle = self.inbound_group_sessions.get(&key)?.map(|p| serde_json::from_slice(&p));
if let Some(pickle) = pickle {
Ok(Some(InboundGroupSession::from_pickle(
pickle?,
self.get_pickle_mode(),
)?))
Ok(Some(InboundGroupSession::from_pickle(pickle?, self.get_pickle_mode())?))
} else {
Ok(None)
}
@ -657,10 +629,7 @@ impl CryptoStore for SledStore {
fn users_for_key_query(&self) -> HashSet<UserId> {
#[allow(clippy::map_clone)]
self.users_for_key_query_cache
.iter()
.map(|u| u.clone())
.collect()
self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()
}
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
@ -714,9 +683,7 @@ impl CryptoStore for SledStore {
}
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
Ok(self
.olm_hashes
.contains_key(serde_json::to_vec(message_hash)?)?)
Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?)
}
async fn get_outgoing_key_request(
@ -752,36 +719,33 @@ impl CryptoStore for SledStore {
}
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
let ret: Result<(), TransactionError<serde_json::Error>> = (
&self.outgoing_key_requests,
&self.unsent_key_requests,
&self.key_requests_by_info,
)
.transaction(
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
let ret: Result<(), TransactionError<serde_json::Error>> =
(&self.outgoing_key_requests, &self.unsent_key_requests, &self.key_requests_by_info)
.transaction(
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
if let Some(request) = sent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
if let Some(request) = sent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
if let Some(request) = unsent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
if let Some(request) = unsent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
Ok(())
},
);
Ok(())
},
);
ret?;
self.inner.flush_async().await?;
@ -846,10 +810,7 @@ mod test {
async fn get_loaded_store() -> (ReadOnlyAccount, SledStore, tempfile::TempDir) {
let (store, dir) = get_store(None).await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
(account, store, dir)
}
@ -863,21 +824,12 @@ mod test {
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob
.one_time_keys()
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key =
bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice
.create_outbound_session_helper(&sender_key, &one_time_key)
.await
.unwrap();
let session =
alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
(alice, session)
}
@ -895,10 +847,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none());
let account = get_account();
store
.save_account(account)
.await
.expect("Can't save account");
store.save_account(account).await.expect("Can't save account");
}
#[async_test]
@ -906,10 +855,7 @@ mod test {
let (store, _dir) = get_store(None).await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
@ -922,10 +868,7 @@ mod test {
let (store, _dir) = get_store(Some("secret_passphrase")).await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
@ -938,50 +881,32 @@ mod test {
let (store, _dir) = get_store(None).await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
account.mark_as_shared();
account.update_uploaded_key_count(50);
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
assert_eq!(account, loaded_account);
assert_eq!(
account.uploaded_key_count(),
loaded_account.uploaded_key_count()
);
assert_eq!(account.uploaded_key_count(), loaded_account.uploaded_key_count());
}
#[async_test]
async fn load_sessions() {
let (store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await;
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let changes = Changes {
sessions: vec![session.clone()],
..Default::default()
};
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.unwrap();
let sessions = store
.get_sessions(&session.sender_key)
.await
.expect("Can't load sessions")
.unwrap();
let sessions =
store.get_sessions(&session.sender_key).await.expect("Can't load sessions").unwrap();
let loaded_session = sessions.lock().await.get(0).cloned().unwrap();
assert_eq!(&session, &loaded_session);
@ -994,15 +919,9 @@ mod test {
let sender_key = session.sender_key.to_owned();
let session_id = session.session_id().to_owned();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let changes = Changes {
sessions: vec![session.clone()],
..Default::default()
};
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
@ -1040,15 +959,9 @@ mod test {
)
.expect("Can't create session");
let changes = Changes {
inbound_group_sessions: vec![session],
..Default::default()
};
let changes = Changes { inbound_group_sessions: vec![session], ..Default::default() };
store
.save_changes(changes)
.await
.expect("Can't save group session");
store.save_changes(changes).await.expect("Can't save group session");
}
#[async_test]
@ -1072,15 +985,10 @@ mod test {
let session = InboundGroupSession::from_export(export).unwrap();
let changes = Changes {
inbound_group_sessions: vec![session.clone()],
..Default::default()
};
let changes =
Changes { inbound_group_sessions: vec![session.clone()], ..Default::default() };
store
.save_changes(changes)
.await
.expect("Can't save group session");
store.save_changes(changes).await.expect("Can't save group session");
drop(store);
@ -1103,21 +1011,12 @@ mod test {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
assert!(store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
assert!(store.is_user_tracked(device.user_id()));
assert!(!store.users_for_key_query().contains(device.user_id()));
assert!(!store
.update_tracked_user(device.user_id(), true)
.await
.unwrap());
assert!(!store.update_tracked_user(device.user_id(), true).await.unwrap());
assert!(store.users_for_key_query().contains(device.user_id()));
drop(store);
@ -1128,10 +1027,7 @@ mod test {
assert!(store.is_user_tracked(device.user_id()));
assert!(store.users_for_key_query().contains(device.user_id()));
store
.update_tracked_user(device.user_id(), false)
.await
.unwrap();
store.update_tracked_user(device.user_id(), false).await.unwrap();
assert!(!store.users_for_key_query().contains(device.user_id()));
drop(store);
@ -1148,10 +1044,7 @@ mod test {
let device = get_device();
let changes = Changes {
devices: DeviceChanges {
changed: vec![device.clone()],
..Default::default()
},
devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
..Default::default()
};
@ -1163,11 +1056,8 @@ mod test {
store.load_account().await.unwrap();
let loaded_device = store
.get_device(device.user_id(), device.device_id())
.await
.unwrap()
.unwrap();
let loaded_device =
store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
assert_eq!(device, loaded_device);
@ -1188,20 +1078,14 @@ mod test {
let device = get_device();
let changes = Changes {
devices: DeviceChanges {
changed: vec![device.clone()],
..Default::default()
},
devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
..Default::default()
};
store.save_changes(changes).await.unwrap();
let changes = Changes {
devices: DeviceChanges {
deleted: vec![device.clone()],
..Default::default()
},
devices: DeviceChanges { deleted: vec![device.clone()], ..Default::default() },
..Default::default()
};
@ -1212,10 +1096,7 @@ mod test {
store.load_account().await.unwrap();
let loaded_device = store
.get_device(device.user_id(), device.device_id())
.await
.unwrap();
let loaded_device = store.get_device(device.user_id(), device.device_id()).await.unwrap();
assert!(loaded_device.is_none());
}
@ -1232,10 +1113,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &device_id);
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let own_identity = get_own_identity();
@ -1247,10 +1125,7 @@ mod test {
..Default::default()
};
store
.save_changes(changes)
.await
.expect("Can't save identity");
store.save_changes(changes).await.expect("Can't save identity");
drop(store);
@ -1258,17 +1133,10 @@ mod test {
store.load_account().await.unwrap();
let loaded_user = store
.get_user_identity(own_identity.user_id())
.await
.unwrap()
.unwrap();
let loaded_user = store.get_user_identity(own_identity.user_id()).await.unwrap().unwrap();
assert_eq!(loaded_user.master_key(), own_identity.master_key());
assert_eq!(
loaded_user.self_signing_key(),
own_identity.self_signing_key()
);
assert_eq!(loaded_user.self_signing_key(), own_identity.self_signing_key());
assert_eq!(loaded_user, own_identity.clone().into());
let other_identity = get_other_identity();
@ -1283,17 +1151,10 @@ mod test {
store.save_changes(changes).await.unwrap();
let loaded_user = store
.get_user_identity(other_identity.user_id())
.await
.unwrap()
.unwrap();
let loaded_user = store.get_user_identity(other_identity.user_id()).await.unwrap().unwrap();
assert_eq!(loaded_user.master_key(), other_identity.master_key());
assert_eq!(
loaded_user.self_signing_key(),
other_identity.self_signing_key()
);
assert_eq!(loaded_user.self_signing_key(), other_identity.self_signing_key());
assert_eq!(loaded_user, other_identity.into());
own_identity.mark_as_verified();
@ -1317,10 +1178,7 @@ mod test {
assert!(store.load_identity().await.unwrap().is_none());
let identity = PrivateCrossSigningIdentity::new(alice_id()).await;
let changes = Changes {
private_identity: Some(identity.clone()),
..Default::default()
};
let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
store.save_changes(changes).await.unwrap();
let loaded_identity = store.load_identity().await.unwrap().unwrap();
@ -1331,10 +1189,8 @@ mod test {
async fn olm_hash_saving() {
let (_, store, _dir) = get_loaded_store().await;
let hash = OlmMessageHash {
sender_key: "test_sender".to_owned(),
hash: "test_hash".to_owned(),
};
let hash =
OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
let mut changes = Changes::default();
changes.message_hashes.push(hash.clone());

View file

@ -83,17 +83,13 @@ impl VerificationMachine {
);
let request = match content.into() {
OutgoingContent::Room(r, c) => RoomMessageRequest {
room_id: r,
txn_id: Uuid::new_v4(),
content: c,
OutgoingContent::Room(r, c) => {
RoomMessageRequest { room_id: r, txn_id: Uuid::new_v4(), content: c }.into()
}
.into(),
OutgoingContent::ToDevice(c) => {
let request = content_to_request(device.user_id(), device.device_id(), c);
self.verifications
.insert(sas.flow_id().as_str().to_owned(), sas.clone());
self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone());
request.into()
}
@ -134,10 +130,7 @@ impl VerificationMachine {
let request = content_to_request(recipient, recipient_device, c);
let request_id = request.txn_id;
let request = OutgoingRequest {
request_id,
request: Arc::new(request.into()),
};
let request = OutgoingRequest { request_id, request: Arc::new(request.into()) };
self.outgoing_to_device_messages.insert(request_id, request);
}
@ -147,12 +140,7 @@ impl VerificationMachine {
let request = OutgoingRequest {
request: Arc::new(
RoomMessageRequest {
room_id: r,
txn_id: request_id,
content: c,
}
.into(),
RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(),
),
request_id,
};
@ -181,32 +169,22 @@ impl VerificationMachine {
}
pub fn outgoing_room_message_requests(&self) -> Vec<OutgoingRequest> {
self.outgoing_room_messages
.iter()
.map(|r| (*r).clone())
.collect()
self.outgoing_room_messages.iter().map(|r| (*r).clone()).collect()
}
pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> {
#[allow(clippy::map_clone)]
self.outgoing_to_device_messages
.iter()
.map(|r| (*r).clone())
.collect()
self.outgoing_to_device_messages.iter().map(|r| (*r).clone()).collect()
}
pub fn garbage_collect(&self) {
self.verifications
.retain(|_, s| !(s.is_done() || s.is_canceled()));
self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
for sas in self.verifications.iter() {
if let Some(r) = sas.cancel_if_timed_out() {
self.outgoing_to_device_messages.insert(
r.request_id(),
OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
},
OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) },
);
}
}
@ -266,10 +244,8 @@ impl VerificationMachine {
);
if let Some((_, request)) = self.requests.remove(&e.content.relation.event_id) {
if let Some(d) = self
.store
.get_device(&e.sender, &e.content.from_device)
.await?
if let Some(d) =
self.store.get_device(&e.sender, &e.content.from_device).await?
{
match request.into_started_sas(
e,
@ -296,7 +272,8 @@ impl VerificationMachine {
"Can't start key verification with {} {}, canceling: {:?}",
e.sender, e.content.from_device, c
);
// self.queue_up_content(&e.sender, &e.content.from_device, c)
// self.queue_up_content(&e.sender,
// &e.content.from_device, c)
}
}
}
@ -375,11 +352,7 @@ impl VerificationMachine {
e.content.from_device
);
if let Some(d) = self
.store
.get_device(&e.sender, &e.content.from_device)
.await?
{
if let Some(d) = self.store.get_device(&e.sender, &e.content.from_device).await? {
let private_identity = self.private_identity.lock().await.clone();
match Sas::from_start_event(
self.account.clone(),
@ -390,8 +363,7 @@ impl VerificationMachine {
self.store.get_user_identity(&e.sender).await?,
) {
Ok(s) => {
self.verifications
.insert(e.content.transaction_id.clone(), s);
self.verifications.insert(e.content.transaction_id.clone(), s);
}
Err(c) => {
warn!(
@ -442,10 +414,7 @@ impl VerificationMachine {
self.outgoing_to_device_messages.insert(
request_id,
OutgoingRequest {
request_id,
request: Arc::new(r.into()),
},
OutgoingRequest { request_id, request: Arc::new(r.into()) },
);
}
}
@ -521,10 +490,7 @@ mod test {
);
machine
.receive_event(&wrap_any_to_device_content(
bob_sas.user_id(),
start_content.into(),
))
.receive_event(&wrap_any_to_device_content(bob_sas.user_id(), start_content.into()))
.await
.unwrap();
@ -559,11 +525,7 @@ mod test {
alice_machine.receive_event(&event).await.unwrap();
assert!(!alice_machine.outgoing_to_device_messages.is_empty());
let request = alice_machine
.outgoing_to_device_messages
.iter()
.next()
.unwrap();
let request = alice_machine.outgoing_to_device_messages.iter().next().unwrap();
let txn_id = *request.request_id();

View file

@ -56,11 +56,7 @@ pub(crate) mod test {
sender: &UserId,
content: OutgoingContent,
) -> AnyToDeviceEvent {
let content = if let OutgoingContent::ToDevice(c) = content {
c
} else {
unreachable!()
};
let content = if let OutgoingContent::ToDevice(c) = content { c } else { unreachable!() };
match content {
AnyToDeviceEventContent::KeyVerificationKey(c) => {
@ -95,22 +91,11 @@ pub(crate) mod test {
pub(crate) fn get_content_from_request(
request: &OutgoingVerificationRequest,
) -> OutgoingContent {
let request = if let OutgoingVerificationRequest::ToDevice(r) = request {
r
} else {
unreachable!()
};
let request =
if let OutgoingVerificationRequest::ToDevice(r) = request { r } else { unreachable!() };
let json: Value = serde_json::from_str(
request
.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
request.messages.values().next().unwrap().values().next().unwrap().get(),
)
.unwrap();

View file

@ -106,15 +106,13 @@ impl VerificationRequest {
content: &KeyVerificationRequestEventContent,
) -> Self {
Self {
inner: Arc::new(Mutex::new(InnerRequest::Requested(
RequestState::from_request_event(
account.user_id(),
account.device_id(),
sender,
event_id,
content,
),
))),
inner: Arc::new(Mutex::new(InnerRequest::Requested(RequestState::from_request_event(
account.user_id(),
account.device_id(),
sender,
event_id,
content,
)))),
account,
other_user_id: sender.clone().into(),
private_cross_signing_identity,
@ -285,10 +283,7 @@ impl RequestState<Created> {
own_user_id: self.own_user_id,
own_device_id: self.own_device_id,
other_user_id: self.other_user_id,
state: Sent {
methods: SUPPORTED_METHODS.to_vec(),
flow_id: response.event_id.clone(),
},
state: Sent { methods: SUPPORTED_METHODS.to_vec(), flow_id: response.event_id.clone() },
}
}
}
@ -368,9 +363,7 @@ impl RequestState<Requested> {
let content = ReadyEventContent {
from_device: self.own_device_id,
methods: self.state.methods,
relation: Relation {
event_id: self.state.flow_id,
},
relation: Relation { event_id: self.state.flow_id },
};
(state, content)
@ -580,9 +573,7 @@ mod test {
panic!("Invalid start event content type");
};
let alice_sas = alice_request
.into_started_sas(&event, bob_device, None)
.unwrap();
let alice_sas = alice_request.into_started_sas(&event, bob_device, None).unwrap();
assert!(!bob_sas.is_canceled());
assert!(!alice_sas.is_canceled());

View file

@ -61,10 +61,7 @@ impl StartContent {
StartContent::Room(_, c) => serde_json::to_value(c),
};
content
.expect("Can't serialize content")
.try_into()
.expect("Can't canonicalize content")
content.expect("Can't serialize content").try_into().expect("Can't canonicalize content")
}
}

View file

@ -62,12 +62,7 @@ pub fn calculate_commitment(public_key: &str, content: impl Into<StartContent>)
let content = content.into().canonical_json();
let content_string = content.to_string();
encode(
Sha256::new()
.chain(&public_key)
.chain(&content_string)
.finalize(),
)
encode(Sha256::new().chain(&public_key).chain(&content_string).finalize())
}
/// Get a tuple of an emoji and a description of the emoji using a number.
@ -231,11 +226,7 @@ pub fn receive_mac_event(
.calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC")
{
trace!(
"Successfully verified the device key {} from {}",
key_id,
sender
);
trace!("Successfully verified the device key {} from {}", key_id, sender);
verified_devices.push(ids.other_device.clone());
} else {
@ -250,11 +241,7 @@ pub fn receive_mac_event(
.calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC")
{
trace!(
"Successfully verified the master key {} from {}",
key_id,
sender
);
trace!("Successfully verified the master key {} from {}", key_id, sender);
verified_identities.push(identity.clone())
} else {
return Err(CancelCode::KeyMismatch);
@ -316,8 +303,7 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
mac.insert(
key_id.to_string(),
sas.calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC"),
sas.calculate_mac(key, &format!("{}{}", info, key_id)).expect("Can't calculate SAS MAC"),
);
// TODO Add the cross signing master key here if we trust/have it.
@ -329,23 +315,13 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
.expect("Can't calculate SAS MAC");
match flow_id {
FlowId::ToDevice(s) => MacToDeviceEventContent {
transaction_id: s.to_string(),
keys,
mac,
FlowId::ToDevice(s) => {
MacToDeviceEventContent { transaction_id: s.to_string(), keys, mac }.into()
}
FlowId::InRoom(r, e) => {
(r.clone(), MacEventContent { mac, keys, relation: Relation { event_id: e.clone() } })
.into()
}
.into(),
FlowId::InRoom(r, e) => (
r.clone(),
MacEventContent {
mac,
keys,
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
}
}
@ -366,24 +342,12 @@ fn extra_info_sas(
flow_id: &str,
we_started: bool,
) -> String {
let our_info = format!(
"{}|{}|{}",
ids.account.user_id(),
ids.account.device_id(),
own_pubkey
);
let their_info = format!(
"{}|{}|{}",
ids.other_device.user_id(),
ids.other_device.device_id(),
their_pubkey
);
let our_info = format!("{}|{}|{}", ids.account.user_id(), ids.account.device_id(), own_pubkey);
let their_info =
format!("{}|{}|{}", ids.other_device.user_id(), ids.other_device.device_id(), their_pubkey);
let (first_info, second_info) = if we_started {
(our_info, their_info)
} else {
(their_info, our_info)
};
let (first_info, second_info) =
if we_started { (our_info, their_info) } else { (their_info, our_info) };
let info = format!(
"MATRIX_KEY_VERIFICATION_SAS|{first_info}|{second_info}|{flow_id}",
@ -580,11 +544,7 @@ pub fn content_to_request(
_ => unreachable!(),
};
ToDeviceRequest {
txn_id: Uuid::new_v4(),
event_type,
messages,
}
ToDeviceRequest { txn_id: Uuid::new_v4(), event_type, messages }
}
#[cfg(test)]
@ -622,18 +582,14 @@ mod test {
#[test]
fn emoji_generation() {
let bytes = vec![0, 0, 0, 0, 0, 0];
let index: Vec<(&'static str, &'static str)> = vec![0, 0, 0, 0, 0, 0, 0]
.into_iter()
.map(emoji_from_index)
.collect();
let index: Vec<(&'static str, &'static str)> =
vec![0, 0, 0, 0, 0, 0, 0].into_iter().map(emoji_from_index).collect();
assert_eq!(bytes_to_emoji(bytes), index.as_ref());
let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
let index: Vec<(&'static str, &'static str)> = vec![63, 63, 63, 63, 63, 63, 63]
.into_iter()
.map(emoji_from_index)
.collect();
let index: Vec<(&'static str, &'static str)> =
vec![63, 63, 63, 63, 63, 63, 63].into_iter().map(emoji_from_index).collect();
assert_eq!(bytes_to_emoji(bytes), index.as_ref());
}

View file

@ -278,14 +278,15 @@ impl InnerSas {
_ => (self, None),
},
AnyToDeviceEvent::KeyVerificationMac(e) => match self {
InnerSas::KeyRecieved(s) => match s.into_mac_received(&e.sender, e.content.clone())
{
Ok(s) => (InnerSas::MacReceived(s), None),
Err(s) => {
let content = s.as_content();
(InnerSas::Canceled(s), Some(content.into()))
InnerSas::KeyRecieved(s) => {
match s.into_mac_received(&e.sender, e.content.clone()) {
Ok(s) => (InnerSas::MacReceived(s), None),
Err(s) => {
let content = s.as_content();
(InnerSas::Canceled(s), Some(content.into()))
}
}
},
}
InnerSas::Confirmed(s) => match s.into_done(&e.sender, e.content.clone()) {
Ok(s) => (InnerSas::Done(s), None),
Err(s) => {

View file

@ -150,11 +150,8 @@ impl Sas {
store: Arc<Box<dyn CryptoStore>>,
other_identity: Option<UserIdentities>,
) -> (Sas, StartContent) {
let (inner, content) = InnerSas::start(
account.clone(),
other_device.clone(),
other_identity.clone(),
);
let (inner, content) =
InnerSas::start(account.clone(), other_device.clone(), other_identity.clone());
(
Self::start_helper(
@ -266,22 +263,18 @@ impl Sas {
&self,
settings: AcceptSettings,
) -> Option<OutgoingVerificationRequest> {
self.inner
.lock()
.unwrap()
.accept()
.map(|c| match settings.apply(c) {
AcceptContent::ToDevice(c) => {
let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
self.content_to_request(content).into()
}
AcceptContent::Room(room_id, content) => RoomMessageRequest {
room_id,
txn_id: Uuid::new_v4(),
content: AnyMessageEventContent::KeyVerificationAccept(content),
}
.into(),
})
self.inner.lock().unwrap().accept().map(|c| match settings.apply(c) {
AcceptContent::ToDevice(c) => {
let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
self.content_to_request(content).into()
}
AcceptContent::Room(room_id, content) => RoomMessageRequest {
room_id,
txn_id: Uuid::new_v4(),
content: AnyMessageEventContent::KeyVerificationAccept(content),
}
.into(),
})
}
/// Confirm the Sas verification.
@ -294,10 +287,7 @@ impl Sas {
pub async fn confirm(
&self,
) -> Result<
(
Option<OutgoingVerificationRequest>,
Option<SignatureUploadRequest>,
),
(Option<OutgoingVerificationRequest>, Option<SignatureUploadRequest>),
CryptoStoreError,
> {
let (content, done) = {
@ -310,9 +300,9 @@ impl Sas {
};
let mac_request = content.map(|c| match c {
event_enums::MacContent::ToDevice(c) => self
.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c))
.into(),
event_enums::MacContent::ToDevice(c) => {
self.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c)).into()
}
event_enums::MacContent::Room(r, c) => RoomMessageRequest {
room_id: r,
txn_id: Uuid::new_v4(),
@ -365,10 +355,7 @@ impl Sas {
};
let mut changes = Changes {
devices: DeviceChanges {
changed: vec![device],
..Default::default()
},
devices: DeviceChanges { changed: vec![device], ..Default::default() },
..Default::default()
};
@ -428,10 +415,7 @@ impl Sas {
.map(VerificationResult::SignatureUpload)
.unwrap_or(VerificationResult::Ok))
} else {
Ok(self
.cancel()
.map(VerificationResult::Cancel)
.unwrap_or(VerificationResult::Ok))
Ok(self.cancel().map(VerificationResult::Cancel).unwrap_or(VerificationResult::Ok))
}
}
@ -454,14 +438,8 @@ impl Sas {
.as_ref()
.map_or(false, |i| i.master_key() == identity.master_key())
{
if self
.verified_identities()
.map_or(false, |i| i.contains(&identity))
{
trace!(
"Marking user identity of {} as verified.",
identity.user_id(),
);
if self.verified_identities().map_or(false, |i| i.contains(&identity)) {
trace!("Marking user identity of {} as verified.", identity.user_id(),);
if let UserIdentities::Own(i) = &identity {
i.mark_as_verified();
@ -500,17 +478,11 @@ impl Sas {
pub(crate) async fn mark_device_as_verified(
&self,
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
let device = self
.store
.get_device(self.other_user_id(), self.other_device_id())
.await?;
let device = self.store.get_device(self.other_user_id(), self.other_device_id()).await?;
if let Some(device) = device {
if device.keys() == self.other_device.keys() {
if self
.verified_devices()
.map_or(false, |v| v.contains(&device))
{
if self.verified_devices().map_or(false, |v| v.contains(&device)) {
trace!(
"Marking device {} {} as verified.",
device.user_id(),
@ -571,9 +543,9 @@ impl Sas {
content: AnyMessageEventContent::KeyVerificationCancel(content),
}
.into(),
CancelContent::ToDevice(c) => self
.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c))
.into(),
CancelContent::ToDevice(c) => {
self.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c)).into()
}
})
}
@ -704,9 +676,7 @@ impl AcceptSettings {
///
/// * `methods` - The methods this client allows at most
pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self {
Self {
allowed_methods: methods,
}
Self { allowed_methods: methods }
}
fn apply(self, mut content: AcceptContent) -> AcceptContent {
@ -715,15 +685,8 @@ impl AcceptSettings {
method: AcceptMethod::MSasV1(c),
..
})
| AcceptContent::Room(
_,
AcceptEventContent {
method: AcceptMethod::MSasV1(c),
..
},
) => {
c.short_authentication_string
.retain(|sas| self.allowed_methods.contains(sas));
| AcceptContent::Room(_, AcceptEventContent { method: AcceptMethod::MSasV1(c), .. }) => {
c.short_authentication_string.retain(|sas| self.allowed_methods.contains(sas));
content
}
_ => content,
@ -826,13 +789,7 @@ mod test {
);
alice.receive_event(&event);
assert!(alice
.verified_devices()
.unwrap()
.contains(&alice.other_device()));
assert!(bob
.verified_devices()
.unwrap()
.contains(&bob.other_device()));
assert!(alice.verified_devices().unwrap().contains(&alice.other_device()));
assert!(bob.verified_devices().unwrap().contains(&bob.other_device()));
}
}

View file

@ -59,10 +59,8 @@ const KEY_AGREEMENT_PROTOCOLS: &[KeyAgreementProtocol] =
&[KeyAgreementProtocol::Curve25519HkdfSha256];
const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256];
const MACS: &[MessageAuthenticationCode] = &[MessageAuthenticationCode::HkdfHmacSha256];
const STRINGS: &[ShortAuthenticationString] = &[
ShortAuthenticationString::Decimal,
ShortAuthenticationString::Emoji,
];
const STRINGS: &[ShortAuthenticationString] =
&[ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji];
// The max time a SAS flow can take from start to done.
const MAX_AGE: Duration = Duration::from_secs(60 * 5);
@ -111,9 +109,7 @@ impl TryFrom<AcceptV1Content> for AcceptedProtocols {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|| !HASHES.contains(&content.hash)
|| !MACS.contains(&content.message_authentication_code)
|| (!content
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
|| (!content.short_authentication_string.contains(&ShortAuthenticationString::Emoji)
&& !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal))
@ -372,11 +368,7 @@ impl SasState<Created> {
) -> SasState<Created> {
SasState {
inner: Arc::new(Mutex::new(OlmSas::new())),
ids: SasIds {
account,
other_device,
other_identity,
},
ids: SasIds { account, other_device, other_identity },
verification_flow_id: flow_id.into(),
creation_time: Arc::new(Instant::now()),
@ -411,9 +403,7 @@ impl SasState<Created> {
MSasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."),
),
relation: Relation {
event_id: e.clone(),
},
relation: Relation { event_id: e.clone() },
},
),
}
@ -460,8 +450,8 @@ impl SasState<Created> {
}
impl SasState<Started> {
/// Create a new SAS verification flow from an in-room m.key.verification.start
/// event.
/// Create a new SAS verification flow from an in-room
/// m.key.verification.start event.
///
/// This will put us in the `started` state.
///
@ -502,11 +492,7 @@ impl SasState<Started> {
let sas = SasState {
inner: Arc::new(Mutex::new(sas)),
ids: SasIds {
account,
other_device,
other_identity,
},
ids: SasIds { account, other_device, other_identity },
creation_time: Arc::new(Instant::now()),
last_event_time: Arc::new(Instant::now()),
@ -544,11 +530,7 @@ impl SasState<Started> {
creation_time: Arc::new(Instant::now()),
last_event_time: Arc::new(Instant::now()),
ids: SasIds {
account,
other_device,
other_identity,
},
ids: SasIds { account, other_device, other_identity },
verification_flow_id: content.flow_id().into(),
state: Arc::new(Canceled::new(CancelCode::UnknownMethod)),
@ -582,19 +564,12 @@ impl SasState<Started> {
);
match self.verification_flow_id.as_ref() {
FlowId::ToDevice(s) => AcceptToDeviceEventContent {
transaction_id: s.to_string(),
method,
FlowId::ToDevice(s) => {
AcceptToDeviceEventContent { transaction_id: s.to_string(), method }.into()
}
.into(),
FlowId::InRoom(r, e) => (
r.clone(),
AcceptEventContent {
method,
relation: Relation {
event_id: e.clone(),
},
},
AcceptEventContent { method, relation: Relation { event_id: e.clone() } },
)
.into(),
}
@ -662,10 +637,8 @@ impl SasState<Accepted> {
self.check_event(&sender, content.flow_id().as_str())
.map_err(|c| self.clone().cancel(c))?;
let commitment = calculate_commitment(
content.public_key(),
self.state.start_content.as_ref().clone(),
);
let commitment =
calculate_commitment(content.public_key(), self.state.start_content.as_ref().clone());
if self.state.commitment != commitment {
Err(self.cancel(CancelCode::InvalidMessage))
@ -707,9 +680,7 @@ impl SasState<Accepted> {
r.clone(),
KeyEventContent {
key: self.inner.lock().unwrap().public_key(),
relation: Relation {
event_id: e.clone(),
},
relation: Relation { event_id: e.clone() },
},
)
.into(),
@ -733,9 +704,7 @@ impl SasState<KeyReceived> {
r.clone(),
KeyEventContent {
key: self.inner.lock().unwrap().public_key(),
relation: Relation {
event_id: e.clone(),
},
relation: Relation { event_id: e.clone() },
},
)
.into(),
@ -758,8 +727,8 @@ impl SasState<KeyReceived> {
/// Get the index of the emoji of the short authentication string.
///
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers
/// can be converted to a unique emoji defined by the spec.
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
/// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index(
&self.inner.lock().unwrap(),
@ -930,11 +899,7 @@ impl SasState<Confirmed> {
///
/// The content needs to be automatically sent to the other side.
pub fn as_content(&self) -> MacContent {
get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
}
}
@ -993,8 +958,8 @@ impl SasState<MacReceived> {
/// Get the index of the emoji of the short authentication string.
///
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers
/// can be converted to a unique emoji defined by the spec.
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
/// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index(
&self.inner.lock().unwrap(),
@ -1026,11 +991,7 @@ impl SasState<WaitingForDone> {
/// The content needs to be automatically sent to the other side if it
/// wasn't already sent.
pub fn as_content(&self) -> MacContent {
get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
}
pub fn done_content(&self) -> DoneContent {
@ -1038,15 +999,9 @@ impl SasState<WaitingForDone> {
FlowId::ToDevice(_) => {
unreachable!("The done content isn't supported yet for to-device verifications")
}
FlowId::InRoom(r, e) => (
r.clone(),
DoneEventContent {
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
FlowId::InRoom(r, e) => {
(r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
}
}
}
@ -1088,11 +1043,7 @@ impl SasState<Done> {
/// The content needs to be automatically sent to the other side if it
/// wasn't already sent.
pub fn as_content(&self) -> MacContent {
get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
}
pub fn done_content(&self) -> DoneContent {
@ -1100,15 +1051,9 @@ impl SasState<Done> {
FlowId::ToDevice(_) => {
unreachable!("The done content isn't supported yet for to-device verifications")
}
FlowId::InRoom(r, e) => (
r.clone(),
DoneEventContent {
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
FlowId::InRoom(r, e) => {
(r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
}
}
}
@ -1144,10 +1089,7 @@ impl Canceled {
_ => unimplemented!(),
};
Canceled {
cancel_code: code,
reason,
}
Canceled { cancel_code: code, reason }
}
}
@ -1166,9 +1108,7 @@ impl SasState<Canceled> {
CancelEventContent {
reason: self.state.reason.to_string(),
code: self.state.cancel_code.clone(),
relation: Relation {
event_id: e.clone(),
},
relation: Relation { event_id: e.clone() },
},
)
.into(),
@ -1331,9 +1271,7 @@ mod test {
let content = bob.as_content();
let sender = UserId::try_from("@malory:example.org").unwrap();
alice
.into_accepted(&sender, content)
.expect_err("Didn't cancel on a invalid sender");
alice.into_accepted(&sender, content).expect_err("Didn't cancel on a invalid sender");
}
#[tokio::test]

View file

@ -152,10 +152,7 @@ impl EventBuilder {
}
fn add_joined_event(&mut self, room_id: &RoomId, event: AnySyncRoomEvent) {
self.joined_room_events
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self.joined_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
}
pub fn add_custom_invited_event(
@ -164,10 +161,7 @@ impl EventBuilder {
event: serde_json::Value,
) -> &mut Self {
let event = serde_json::from_value::<AnySyncStateEvent>(event).unwrap();
self.invited_room_events
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self.invited_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
self
}
@ -177,10 +171,7 @@ impl EventBuilder {
event: serde_json::Value,
) -> &mut Self {
let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap();
self.left_room_events
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self.left_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
self
}
@ -350,9 +341,7 @@ impl EventBuilder {
pub fn build_sync_response(&mut self) -> SyncResponse {
let body = self.build_json_sync_response();
let response = Response::builder()
.body(serde_json::to_vec(&body).unwrap())
.unwrap();
let response = Response::builder().body(serde_json::to_vec(&body).unwrap()).unwrap();
SyncResponse::try_from_http_response(response).unwrap()
}
@ -393,15 +382,10 @@ pub fn sync_response(kind: SyncResponseFile) -> SyncResponse {
SyncResponseFile::Voip => &test_json::VOIP_SYNC,
};
let response = Response::builder()
.body(data.to_string().as_bytes().to_vec())
.unwrap();
let response = Response::builder().body(data.to_string().as_bytes().to_vec()).unwrap();
SyncResponse::try_from_http_response(response).unwrap()
}
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
Response::builder()
.status(200)
.body(json.to_string().as_bytes().to_vec())
.unwrap()
Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
}

View file

@ -1,5 +1,6 @@
comment_width = 100
max_width = 100
comment_width = 80
wrap_comments = true
imports_granularity = "Crate"
max_width = 100
use_small_heuristics = "Max"
group_imports = "StdExternalCrate"