Add use_small_heuristics option and run fmt

master
Devin Ragotzy 2021-05-12 11:00:47 -04:00
parent c85f4d4f0c
commit 2ef0c2959c
69 changed files with 1563 additions and 3660 deletions

View File

@ -72,15 +72,11 @@ async fn login_and_sync(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new_with_config(homeserver_url, client_config).unwrap(); let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client client.login(username, password, None, Some("autojoin bot")).await?;
.login(username, password, None, Some("autojoin bot"))
.await?;
println!("logged in as {}", username); println!("logged in as {}", username);
client client.set_event_handler(Box::new(AutoJoinBot::new(client.clone()))).await;
.set_event_handler(Box::new(AutoJoinBot::new(client.clone())))
.await;
client.sync(SyncSettings::default()).await; client.sync(SyncSettings::default()).await;

View File

@ -69,9 +69,7 @@ async fn login_and_sync(
// create a new Client with the given homeserver url and config // create a new Client with the given homeserver url and config
let client = Client::new_with_config(homeserver_url, client_config).unwrap(); let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client client.login(&username, &password, None, Some("command bot")).await?;
.login(&username, &password, None, Some("command bot"))
.await?;
println!("logged in as {}", username); println!("logged in as {}", username);

View File

@ -21,11 +21,7 @@ fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> Aut
auth_parameters.insert("identifier".to_owned(), identifier); auth_parameters.insert("identifier".to_owned(), identifier);
auth_parameters.insert("password".to_owned(), password.to_owned().into()); auth_parameters.insert("password".to_owned(), password.to_owned().into());
AuthData::DirectRequest { AuthData::DirectRequest { kind: "m.login.password", auth_parameters, session }
kind: "m.login.password",
auth_parameters,
session,
}
} }
async fn bootstrap(client: Client, user_id: UserId, password: String) { async fn bootstrap(client: Client, user_id: UserId, password: String) {
@ -33,9 +29,7 @@ async fn bootstrap(client: Client, user_id: UserId, password: String) {
let mut input = String::new(); let mut input = String::new();
io::stdin() io::stdin().read_line(&mut input).expect("error: unable to read user input");
.read_line(&mut input)
.expect("error: unable to read user input");
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
if let Err(e) = client.bootstrap_cross_signing(None).await { if let Err(e) = client.bootstrap_cross_signing(None).await {
@ -62,9 +56,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
let response = client let response = client.login(username, password, None, Some("rust-sdk")).await?;
.login(username, password, None, Some("rust-sdk"))
.await?;
let user_id = &response.user_id; let user_id = &response.user_id;
let client_ref = &client; let client_ref = &client;

View File

@ -19,9 +19,7 @@ async fn wait_for_confirmation(client: Client, sas: Sas) {
println!("Does the emoji match: {:?}", sas.emoji()); println!("Does the emoji match: {:?}", sas.emoji());
let mut input = String::new(); let mut input = String::new();
io::stdin() io::stdin().read_line(&mut input).expect("error: unable to read user input");
.read_line(&mut input)
.expect("error: unable to read user input");
match input.trim().to_lowercase().as_ref() { match input.trim().to_lowercase().as_ref() {
"yes" | "true" | "ok" => { "yes" | "true" | "ok" => {
@ -68,9 +66,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
client client.login(username, password, None, Some("rust-sdk")).await?;
.login(username, password, None, Some("rust-sdk"))
.await?;
let client_ref = &client; let client_ref = &client;
let initial_sync = Arc::new(AtomicBool::from(true)); let initial_sync = Arc::new(AtomicBool::from(true));
@ -81,12 +77,7 @@ async fn login(
let client = &client_ref; let client = &client_ref;
let initial = &initial_ref; let initial = &initial_ref;
for event in response for event in response.to_device.events.iter().filter_map(|e| e.deserialize().ok()) {
.to_device
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
match event { match event {
AnyToDeviceEvent::KeyVerificationStart(e) => { AnyToDeviceEvent::KeyVerificationStart(e) => {
let sas = client let sas = client
@ -129,11 +120,8 @@ async fn login(
if !initial.load(Ordering::SeqCst) { if !initial.load(Ordering::SeqCst) {
for (_room_id, room_info) in response.rooms.join { for (_room_id, room_info) in response.rooms.join {
for event in room_info for event in
.timeline room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
{ {
if let AnySyncRoomEvent::Message(event) = event { if let AnySyncRoomEvent::Message(event) = event {
match event { match event {

View File

@ -28,10 +28,7 @@ async fn get_profile(client: Client, mxid: &UserId) -> MatrixResult<UserProfile>
// Use the response and construct a UserProfile struct. // Use the response and construct a UserProfile struct.
// See https://docs.rs/ruma-client-api/0.9.0/ruma_client_api/r0/profile/get_profile/struct.Response.html // See https://docs.rs/ruma-client-api/0.9.0/ruma_client_api/r0/profile/get_profile/struct.Response.html
// for details on the Response for this Request // for details on the Response for this Request
let user_profile = UserProfile { let user_profile = UserProfile { avatar_url: resp.avatar_url, displayname: resp.displayname };
avatar_url: resp.avatar_url,
displayname: resp.displayname,
};
Ok(user_profile) Ok(user_profile)
} }
@ -43,9 +40,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
client client.login(username, password, None, Some("rust-sdk")).await?;
.login(username, password, None, Some("rust-sdk"))
.await?;
Ok(client) Ok(client)
} }

View File

@ -52,9 +52,7 @@ impl EventHandler for ImageBot {
println!("sending image"); println!("sending image");
let mut image = self.image.lock().await; let mut image = self.image.lock().await;
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None) room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap();
.await
.unwrap();
image.seek(SeekFrom::Start(0)).unwrap(); image.seek(SeekFrom::Start(0)).unwrap();
@ -73,14 +71,10 @@ async fn login_and_sync(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
client client.login(&username, &password, None, Some("command bot")).await?;
.login(&username, &password, None, Some("command bot"))
.await?;
client.sync_once(SyncSettings::default()).await.unwrap(); client.sync_once(SyncSettings::default()).await.unwrap();
client client.set_event_handler(Box::new(ImageBot::new(image))).await;
.set_event_handler(Box::new(ImageBot::new(image)))
.await;
let settings = SyncSettings::default().token(client.sync_token().await.unwrap()); let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
client.sync(settings).await; client.sync(settings).await;
@ -91,12 +85,8 @@ async fn login_and_sync(
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), matrix_sdk::Error> { async fn main() -> Result<(), matrix_sdk::Error> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let (homeserver_url, username, password, image_path) = match ( let (homeserver_url, username, password, image_path) =
env::args().nth(1), match (env::args().nth(1), env::args().nth(2), env::args().nth(3), env::args().nth(4)) {
env::args().nth(2),
env::args().nth(3),
env::args().nth(4),
) {
(Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d), (Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d),
_ => { _ => {
eprintln!( eprintln!(
@ -107,10 +97,7 @@ async fn main() -> Result<(), matrix_sdk::Error> {
} }
}; };
println!( println!("helloooo {} {} {} {:#?}", homeserver_url, username, password, image_path);
"helloooo {} {} {} {:#?}",
homeserver_url, username, password, image_path
);
let path = PathBuf::from(image_path); let path = PathBuf::from(image_path);
let image = File::open(path).expect("Can't open image file."); let image = File::open(path).expect("Can't open image file.");

View File

@ -28,9 +28,7 @@ impl EventHandler for EventCallback {
} = event } = event
{ {
let member = room.get_member(&sender).await.unwrap().unwrap(); let member = room.get_member(&sender).await.unwrap().unwrap();
let name = member let name = member.display_name().unwrap_or_else(|| member.user_id().as_str());
.display_name()
.unwrap_or_else(|| member.user_id().as_str());
println!("{}: {}", name, msg_body); println!("{}: {}", name, msg_body);
} }
} }
@ -47,9 +45,7 @@ async fn login(
client.set_event_handler(Box::new(EventCallback)).await; client.set_event_handler(Box::new(EventCallback)).await;
client client.login(username, password, None, Some("rust-sdk")).await?;
.login(username, password, None, Some("rust-sdk"))
.await?;
client.sync(SyncSettings::new()).await; client.sync(SyncSettings::new()).await;
Ok(()) Ok(())

File diff suppressed because it is too large Load Diff

View File

@ -65,10 +65,7 @@ impl Device {
let (sas, request) = self.inner.start_verification().await?; let (sas, request) = self.inner.start_verification().await?;
self.client.send_to_device(&request).await?; self.client.send_to_device(&request).await?;
Ok(Sas { Ok(Sas { inner: sas, client: self.client.clone() })
inner: sas,
client: self.client.clone(),
})
} }
/// Is the device trusted. /// Is the device trusted.
@ -102,10 +99,7 @@ pub struct UserDevices {
impl UserDevices { impl UserDevices {
/// Get the specific device with the given device id. /// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<Device> { pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
self.inner.get(device_id).map(|d| Device { self.inner.get(device_id).map(|d| Device { inner: d, client: self.client.clone() })
inner: d,
client: self.client.clone(),
})
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
@ -117,9 +111,6 @@ impl UserDevices {
pub fn devices(&self) -> impl Iterator<Item = Device> + '_ { pub fn devices(&self) -> impl Iterator<Item = Device> + '_ {
let client = self.client.clone(); let client = self.client.clone();
self.inner.devices().map(move |d| Device { self.inner.devices().map(move |d| Device { inner: d, client: client.clone() })
inner: d,
client: client.clone(),
})
} }
} }

View File

@ -43,11 +43,13 @@ pub enum HttpError {
#[error(transparent)] #[error(transparent)]
Reqwest(#[from] ReqwestError), Reqwest(#[from] ReqwestError),
/// Queried endpoint requires authentication but was called on an anonymous client. /// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")] #[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired, AuthenticationRequired,
/// Client tried to force authentication but did not provide an access token. /// Client tried to force authentication but did not provide an access
/// token.
#[error("tried to force authentication but no access token was provided")] #[error("tried to force authentication but no access token was provided")]
ForcedAuthenticationWithoutAccessToken, ForcedAuthenticationWithoutAccessToken,
@ -69,9 +71,10 @@ pub enum HttpError {
/// An error occurred while authenticating. /// An error occurred while authenticating.
/// ///
/// When registering or authenticating the Matrix server can send a `UiaaResponse` /// When registering or authenticating the Matrix server can send a
/// as the error type, this is a User-Interactive Authentication API response. This /// `UiaaResponse` as the error type, this is a User-Interactive
/// represents an error with information about how to authenticate the user. /// Authentication API response. This represents an error with
/// information about how to authenticate the user.
#[error(transparent)] #[error(transparent)]
UiaaError(#[from] FromHttpResponseError<UiaaError>), UiaaError(#[from] FromHttpResponseError<UiaaError>),
@ -96,7 +99,8 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
Http(#[from] HttpError), Http(#[from] HttpError),
/// Queried endpoint requires authentication but was called on an anonymous client. /// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")] #[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired, AuthenticationRequired,

View File

@ -87,39 +87,24 @@ impl Handler {
for (room_id, room_info) in &response.rooms.join { for (room_id, room_info) in &response.rooms.join {
if let Some(room) = self.get_room(room_id) { if let Some(room) = self.get_room(room_id) {
for event in room_info for event in room_info.ephemeral.events.iter().filter_map(|e| e.deserialize().ok())
.ephemeral
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{ {
self.handle_ephemeral_event(room.clone(), &event).await; self.handle_ephemeral_event(room.clone(), &event).await;
} }
for event in room_info for event in
.account_data room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{ {
self.handle_room_account_data_event(room.clone(), &event) self.handle_room_account_data_event(room.clone(), &event)
.await; .await;
} }
for event in room_info for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
.state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
self.handle_state_event(room.clone(), &event).await; self.handle_state_event(room.clone(), &event).await;
} }
for event in room_info for event in
.timeline room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
{ {
self.handle_timeline_event(room.clone(), &event).await; self.handle_timeline_event(room.clone(), &event).await;
} }
@ -128,30 +113,19 @@ impl Handler {
for (room_id, room_info) in &response.rooms.leave { for (room_id, room_info) in &response.rooms.leave {
if let Some(room) = self.get_room(room_id) { if let Some(room) = self.get_room(room_id) {
for event in room_info for event in
.account_data room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{ {
self.handle_room_account_data_event(room.clone(), &event) self.handle_room_account_data_event(room.clone(), &event)
.await; .await;
} }
for event in room_info for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
.state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
self.handle_state_event(room.clone(), &event).await; self.handle_state_event(room.clone(), &event).await;
} }
for event in room_info for event in
.timeline room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
{ {
self.handle_timeline_event(room.clone(), &event).await; self.handle_timeline_event(room.clone(), &event).await;
} }
@ -160,31 +134,22 @@ impl Handler {
for (room_id, room_info) in &response.rooms.invite { for (room_id, room_info) in &response.rooms.invite {
if let Some(room) = self.get_room(room_id) { if let Some(room) = self.get_room(room_id) {
for event in room_info for event in
.invite_state room_info.invite_state.events.iter().filter_map(|e| e.deserialize().ok())
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{ {
self.handle_stripped_state_event(room.clone(), &event).await; self.handle_stripped_state_event(room.clone(), &event).await;
} }
} }
} }
for event in response for event in response.presence.events.iter().filter_map(|e| e.deserialize().ok()) {
.presence
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
self.on_presence_event(&event).await; self.on_presence_event(&event).await;
} }
for (room_id, notifications) in &response.notifications { for (room_id, notifications) in &response.notifications {
if let Some(room) = self.get_room(&room_id) { if let Some(room) = self.get_room(&room_id) {
for notification in notifications { for notification in notifications {
self.on_room_notification(room.clone(), notification.clone()) self.on_room_notification(room.clone(), notification.clone()).await;
.await;
} }
} }
} }
@ -248,8 +213,7 @@ impl Handler {
self.on_room_tombstone(room, &tomb).await self.on_room_tombstone(room, &tomb).await
} }
AnySyncStateEvent::Custom(custom) => { AnySyncStateEvent::Custom(custom) => {
self.on_custom_event(room, &CustomEvent::State(custom)) self.on_custom_event(room, &CustomEvent::State(custom)).await
.await
} }
_ => {} _ => {}
} }
@ -267,8 +231,7 @@ impl Handler {
} }
AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, &name).await, AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, &name).await,
AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => { AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => {
self.on_stripped_state_canonical_alias(room, &canonical) self.on_stripped_state_canonical_alias(room, &canonical).await
.await
} }
AnyStrippedStateEvent::RoomAliases(aliases) => { AnyStrippedStateEvent::RoomAliases(aliases) => {
self.on_stripped_state_aliases(room, &aliases).await self.on_stripped_state_aliases(room, &aliases).await
@ -340,8 +303,9 @@ pub enum CustomEvent<'c> {
StrippedState(&'c StrippedStateEvent<CustomEventContent>), StrippedState(&'c StrippedStateEvent<CustomEventContent>),
} }
/// This trait allows any type implementing `EventHandler` to specify event callbacks for each /// This trait allows any type implementing `EventHandler` to specify event
/// event. The `Client` calls each method when the corresponding event is received. /// callbacks for each event. The `Client` calls each method when the
/// corresponding event is received.
/// ///
/// # Examples /// # Examples
/// ``` /// ```
@ -426,8 +390,8 @@ pub trait EventHandler: Send + Sync {
/// Fires when `Client` receives a `RoomEvent::Tombstone` event. /// Fires when `Client` receives a `RoomEvent::Tombstone` event.
async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {} async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {}
/// Fires when `Client` receives room events that trigger notifications according to /// Fires when `Client` receives room events that trigger notifications
/// the push rules of the user. /// according to the push rules of the user.
async fn on_room_notification(&self, _: Room, _: Notification) {} async fn on_room_notification(&self, _: Room, _: Notification) {}
// `RoomEvent`s from `IncomingState` // `RoomEvent`s from `IncomingState`
@ -452,7 +416,8 @@ pub trait EventHandler: Send + Sync {
async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {} async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {}
// `AnyStrippedStateEvent`s // `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member( async fn on_stripped_state_member(
&self, &self,
_: Room, _: Room,
@ -460,32 +425,38 @@ pub trait EventHandler: Send + Sync {
_: Option<MemberEventContent>, _: Option<MemberEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName`
/// event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {} async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias( async fn on_stripped_state_canonical_alias(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>, _: &StrippedStateEvent<CanonicalAliasEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases( async fn on_stripped_state_aliases(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<AliasesEventContent>, _: &StrippedStateEvent<AliasesEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent<AvatarEventContent>) {} async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels( async fn on_stripped_state_power_levels(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<PowerLevelsEventContent>, _: &StrippedStateEvent<PowerLevelsEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules( async fn on_stripped_state_join_rules(
&self, &self,
_: Room, _: Room,
@ -522,17 +493,18 @@ pub trait EventHandler: Send + Sync {
/// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event. /// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event.
async fn on_presence_event(&self, _: &PresenceEvent) {} async fn on_presence_event(&self, _: &PresenceEvent) {}
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails /// Fires when `Client` receives a `Event::Custom` event or if
/// because the event was unknown to ruma. /// deserialization fails because the event was unknown to ruma.
/// ///
/// The only guarantee this method can give about the event is that it is valid JSON. /// The only guarantee this method can give about the event is that it is
/// valid JSON.
async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {} async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {}
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails /// Fires when `Client` receives a `Event::Custom` event or if
/// because the event was unknown to ruma. /// deserialization fails because the event was unknown to ruma.
/// ///
/// The only guarantee this method can give about the event is that it is in the /// The only guarantee this method can give about the event is that it is in
/// shape of a valid matrix event. /// the shape of a valid matrix event.
async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {} async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {}
} }
@ -640,57 +612,50 @@ mod test {
} }
// `AnyStrippedStateEvent`s // `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member( async fn on_stripped_state_member(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<MemberEventContent>, _: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>, _: Option<MemberEventContent>,
) { ) {
self.0 self.0.lock().await.push("stripped state member".to_string())
.lock()
.await
.push("stripped state member".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomName` event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) { async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {
self.0.lock().await.push("stripped state name".to_string()) self.0.lock().await.push("stripped state name".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` /// Fires when `Client` receives a
/// event. /// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias( async fn on_stripped_state_canonical_alias(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>, _: &StrippedStateEvent<CanonicalAliasEventContent>,
) { ) {
self.0 self.0.lock().await.push("stripped state canonical".to_string())
.lock()
.await
.push("stripped state canonical".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases( async fn on_stripped_state_aliases(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<AliasesEventContent>, _: &StrippedStateEvent<AliasesEventContent>,
) { ) {
self.0 self.0.lock().await.push("stripped state aliases".to_string())
.lock()
.await
.push("stripped state aliases".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar( async fn on_stripped_state_avatar(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<AvatarEventContent>, _: &StrippedStateEvent<AvatarEventContent>,
) { ) {
self.0 self.0.lock().await.push("stripped state avatar".to_string())
.lock()
.await
.push("stripped state avatar".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels( async fn on_stripped_state_power_levels(
&self, &self,
_: Room, _: Room,
@ -698,7 +663,8 @@ mod test {
) { ) {
self.0.lock().await.push("stripped state power".to_string()) self.0.lock().await.push("stripped state power".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules( async fn on_stripped_state_join_rules(
&self, &self,
_: Room, _: Room,
@ -769,10 +735,7 @@ mod test {
} }
async fn mock_sync(client: &Client, response: String) { async fn mock_sync(client: &Client, response: String) {
let _m = mock( let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()))
"GET",
Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()),
)
.with_status(200) .with_status(200)
.match_header("authorization", "Bearer 1234") .match_header("authorization", "Bearer 1234")
.with_body(response) .with_body(response)
@ -824,14 +787,7 @@ mod test {
mock_sync(&client, test_json::INVITE_SYNC.to_string()).await; mock_sync(&client, test_json::INVITE_SYNC.to_string()).await;
let v = test_vec.lock().await; let v = test_vec.lock().await;
assert_eq!( assert_eq!(v.as_slice(), ["stripped state name", "stripped state member", "presence event"],)
v.as_slice(),
[
"stripped state name",
"stripped state member",
"presence event"
],
)
} }
#[async_test] #[async_test]
@ -898,15 +854,7 @@ mod test {
mock_sync(&client, test_json::VOIP_SYNC.to_string()).await; mock_sync(&client, test_json::VOIP_SYNC.to_string()).await;
let v = test_vec.lock().await; let v = test_vec.lock().await;
assert_eq!( assert_eq!(v.as_slice(), ["call invite", "call answer", "call candidates", "call hangup",],)
v.as_slice(),
[
"call invite",
"call answer",
"call candidates",
"call hangup",
],
)
} }
#[async_test] #[async_test]

View File

@ -121,8 +121,7 @@ impl HttpClient {
let request = if !self.request_config.assert_identity { let request = if !self.request_config.assert_identity {
self.try_into_http_request(request, session, config).await? self.try_into_http_request(request, session, config).await?
} else { } else {
self.try_into_http_request_with_identy_assertion(request, session, config) self.try_into_http_request_with_identy_assertion(request, session, config).await?
.await?
}; };
self.inner.send_request(request, config).await self.inner.send_request(request, config).await
@ -201,9 +200,7 @@ impl HttpClient {
request: create_content::Request<'_>, request: create_content::Request<'_>,
config: Option<RequestConfig>, config: Option<RequestConfig>,
) -> Result<create_content::Response, HttpError> { ) -> Result<create_content::Response, HttpError> {
let response = self let response = self.send_request(request, self.session.clone(), config).await?;
.send_request(request, self.session.clone(), config)
.await?;
Ok(create_content::Response::try_from_http_response(response)?) Ok(create_content::Response::try_from_http_response(response)?)
} }
@ -216,9 +213,7 @@ impl HttpClient {
Request: OutgoingRequest + Debug, Request: OutgoingRequest + Debug,
HttpError: From<FromHttpResponseError<Request::EndpointError>>, HttpError: From<FromHttpResponseError<Request::EndpointError>>,
{ {
let response = self let response = self.send_request(request, self.session.clone(), config).await?;
.send_request(request, self.session.clone(), config)
.await?;
trace!("Got response: {:?}", response); trace!("Got response: {:?}", response);
@ -255,9 +250,7 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result<Client, HttpEr
headers.insert(reqwest::header::USER_AGENT, user_agent); headers.insert(reqwest::header::USER_AGENT, user_agent);
http_client http_client.default_headers(headers).timeout(config.request_config.timeout)
.default_headers(headers)
.timeout(config.request_config.timeout)
}; };
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
@ -273,9 +266,7 @@ async fn response_to_http_response(
let status = response.status(); let status = response.status();
let mut http_builder = HttpResponse::builder().status(status); let mut http_builder = HttpResponse::builder().status(status);
let headers = http_builder let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
.headers_mut()
.expect("Can't get the response builder headers");
for (k, v) in response.headers_mut().drain() { for (k, v) in response.headers_mut().drain() {
if let Some(key) = k { if let Some(key) = k {
@ -285,9 +276,7 @@ async fn response_to_http_response(
let body = response.bytes().await?; let body = response.bytes().await?;
Ok(http_builder Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
.body(body)
.expect("Can't construct a response using the given body"))
} }
#[cfg(any(target_arch = "wasm32"))] #[cfg(any(target_arch = "wasm32"))]
@ -328,18 +317,12 @@ async fn send_request(
}; };
// Turn errors into permanent errors when the retry limit is reached // Turn errors into permanent errors when the retry limit is reached
let error_type = if stop { let error_type = if stop { RetryError::Permanent } else { RetryError::Transient };
RetryError::Permanent
} else {
RetryError::Transient
};
let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?; let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?;
let response = client let response =
.execute(request) client.execute(request).await.map_err(|e| error_type(HttpError::Reqwest(e)))?;
.await
.map_err(|e| error_type(HttpError::Reqwest(e)))?;
let status_code = response.status(); let status_code = response.status();
// TODO TOO_MANY_REQUESTS will have a retry timeout which we should // TODO TOO_MANY_REQUESTS will have a retry timeout which we should

View File

@ -36,10 +36,7 @@ impl Common {
/// * `room` - The underlaying room. /// * `room` - The underlaying room.
pub fn new(client: Client, room: BaseRoom) -> Self { pub fn new(client: Client, room: BaseRoom) -> Self {
// TODO: Make this private // TODO: Make this private
Self { Self { inner: room, client }
inner: room,
client,
}
} }
/// Leave this room. /// Leave this room.
@ -152,35 +149,25 @@ impl Common {
pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> { pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
if let Some(mutex) = self if let Some(mutex) =
.client self.client.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
.members_request_locks
.get(self.inner.room_id())
.map(|m| m.clone())
{ {
mutex.lock().await; mutex.lock().await;
Ok(None) Ok(None)
} else { } else {
let mutex = Arc::new(Mutex::new(())); let mutex = Arc::new(Mutex::new(()));
self.client self.client.members_request_locks.insert(self.inner.room_id().clone(), mutex.clone());
.members_request_locks
.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await; let _guard = mutex.lock().await;
let request = get_member_events::Request::new(self.inner.room_id()); let request = get_member_events::Request::new(self.inner.room_id());
let response = self.client.send(request, None).await?; let response = self.client.send(request, None).await?;
let response = self let response =
.client self.client.base_client.receive_members(self.inner.room_id(), &response).await?;
.base_client
.receive_members(self.inner.room_id(), &response)
.await?;
self.client self.client.members_request_locks.remove(self.inner.room_id());
.members_request_locks
.remove(self.inner.room_id());
Ok(Some(response)) Ok(Some(response))
} }

View File

@ -21,9 +21,7 @@ impl Invited {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> { pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private // TODO: Make this private
if room.room_type() == RoomType::Invited { if room.room_type() == RoomType::Invited {
Some(Self { Some(Self { inner: Common::new(client, room) })
inner: Common::new(client, room),
})
} else { } else {
None None
} }

View File

@ -47,8 +47,9 @@ const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3);
/// A room in the joined state. /// A room in the joined state.
/// ///
/// The `JoinedRoom` contains all methodes specific to a `Room` with type `RoomType::Joined`. /// The `JoinedRoom` contains all methodes specific to a `Room` with type
/// Operations may fail once the underlaying `Room` changes `RoomType`. /// `RoomType::Joined`. Operations may fail once the underlaying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Joined { pub struct Joined {
pub(crate) inner: Common, pub(crate) inner: Common,
@ -63,7 +64,8 @@ impl Deref for Joined {
} }
impl Joined { impl Joined {
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type `RoomType::Joined`. /// Create a new `room::Joined` if the underlaying `BaseRoom` has type
/// `RoomType::Joined`.
/// ///
/// # Arguments /// # Arguments
/// * `client` - The client used to make requests. /// * `client` - The client used to make requests.
@ -72,9 +74,7 @@ impl Joined {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> { pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private // TODO: Make this private
if room.room_type() == RoomType::Joined { if room.room_type() == RoomType::Joined {
Some(Self { Some(Self { inner: Common::new(client, room) })
inner: Common::new(client, room),
})
} else { } else {
None None
} }
@ -93,9 +93,7 @@ impl Joined {
/// ///
/// * `reason` - The reason for banning this user. /// * `reason` - The reason for banning this user.
pub async fn ban_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> { pub async fn ban_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), { let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), { reason });
reason
});
self.client.send(request, None).await?; self.client.send(request, None).await?;
Ok(()) Ok(())
} }
@ -104,13 +102,12 @@ impl Joined {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `user_id` - The `UserId` of the user that should be kicked out of the room. /// * `user_id` - The `UserId` of the user that should be kicked out of the
/// room.
/// ///
/// * `reason` - Optional reason why the room member is being kicked out. /// * `reason` - Optional reason why the room member is being kicked out.
pub async fn kick_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> { pub async fn kick_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), { let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), { reason });
reason
});
self.client.send(request, None).await?; self.client.send(request, None).await?;
Ok(()) Ok(())
} }
@ -144,10 +141,11 @@ impl Joined {
/// Activate typing notice for this room. /// Activate typing notice for this room.
/// ///
/// The typing notice remains active for 4s. It can be deactivate at any point by setting /// The typing notice remains active for 4s. It can be deactivate at any
/// typing to `false`. If this method is called while the typing notice is active nothing will /// point by setting typing to `false`. If this method is called while
/// happen. This method can be called on every key stroke, since it will do nothing while /// the typing notice is active nothing will happen. This method can be
/// typing is active. /// called on every key stroke, since it will do nothing while typing is
/// active.
/// ///
/// # Arguments /// # Arguments
/// ///
@ -179,21 +177,23 @@ impl Joined {
/// # }); /// # });
/// ``` /// ```
pub async fn typing_notice(&self, typing: bool) -> Result<()> { pub async fn typing_notice(&self, typing: bool) -> Result<()> {
// Only send a request to the homeserver if the old timeout has elapsed or the typing // Only send a request to the homeserver if the old timeout has elapsed
// notice changed state within the TYPING_NOTICE_TIMEOUT // or the typing notice changed state within the
// TYPING_NOTICE_TIMEOUT
let send = let send =
if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) { if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) {
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT { if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
// We always reactivate the typing notice if typing is true or we may need to // We always reactivate the typing notice if typing is true or
// deactivate it if it's currently active if typing is false // we may need to deactivate it if it's
// currently active if typing is false
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
} else { } else {
// Only send a request when we need to deactivate typing // Only send a request when we need to deactivate typing
!typing !typing
} }
} else { } else {
// Typing notice is currently deactivated, therefore, send a request only when it's // Typing notice is currently deactivated, therefore, send a request
// about to be activated // only when it's about to be activated
typing typing
}; };
@ -216,11 +216,13 @@ impl Joined {
Ok(()) Ok(())
} }
/// Send a request to notify this room that the user has read specific event. /// Send a request to notify this room that the user has read specific
/// event.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `event_id` - The `EventId` specifies the event to set the read receipt on. /// * `event_id` - The `EventId` specifies the event to set the read receipt
/// on.
pub async fn read_receipt(&self, event_id: &EventId) -> Result<()> { pub async fn read_receipt(&self, event_id: &EventId) -> Result<()> {
let request = let request =
create_receipt::Request::new(self.inner.room_id(), ReceiptType::Read, event_id); create_receipt::Request::new(self.inner.room_id(), ReceiptType::Read, event_id);
@ -229,22 +231,23 @@ impl Joined {
Ok(()) Ok(())
} }
/// Send a request to notify this room that the user has read up to specific event. /// Send a request to notify this room that the user has read up to specific
/// event.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * fully_read - The `EventId` of the event the user has read to. /// * fully_read - The `EventId` of the event the user has read to.
/// ///
/// * read_receipt - An `EventId` to specify the event to set the read receipt on. /// * read_receipt - An `EventId` to specify the event to set the read
/// receipt on.
pub async fn read_marker( pub async fn read_marker(
&self, &self,
fully_read: &EventId, fully_read: &EventId,
read_receipt: Option<&EventId>, read_receipt: Option<&EventId>,
) -> Result<()> { ) -> Result<()> {
let request = assign!( let request = assign!(set_read_marker::Request::new(self.inner.room_id(), fully_read), {
set_read_marker::Request::new(self.inner.room_id(), fully_read), read_receipt
{ read_receipt } });
);
self.client.send(request, None).await?; self.client.send(request, None).await?;
Ok(()) Ok(())
@ -262,11 +265,8 @@ impl Joined {
// TODO expose this publicly so people can pre-share a group session if // TODO expose this publicly so people can pre-share a group session if
// e.g. a user starts to type a message for a room. // e.g. a user starts to type a message for a room.
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
if let Some(mutex) = self if let Some(mutex) =
.client self.client.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
.group_session_locks
.get(self.inner.room_id())
.map(|m| m.clone())
{ {
// If a group session share request is already going on, // If a group session share request is already going on,
// await the release of the lock. // await the release of the lock.
@ -275,23 +275,14 @@ impl Joined {
// Otherwise create a new lock and share the group // Otherwise create a new lock and share the group
// session. // session.
let mutex = Arc::new(Mutex::new(())); let mutex = Arc::new(Mutex::new(()));
self.client self.client.group_session_locks.insert(self.inner.room_id().clone(), mutex.clone());
.group_session_locks
.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await; let _guard = mutex.lock().await;
{ {
let joined = self let joined = self.client.store().get_joined_user_ids(self.inner.room_id()).await?;
.client let invited =
.store() self.client.store().get_invited_user_ids(self.inner.room_id()).await?;
.get_joined_user_ids(self.inner.room_id())
.await?;
let invited = self
.client
.store()
.get_invited_user_ids(self.inner.room_id())
.await?;
let members = joined.iter().chain(&invited); let members = joined.iter().chain(&invited);
self.client.claim_one_time_keys(members).await?; self.client.claim_one_time_keys(members).await?;
}; };
@ -304,10 +295,7 @@ impl Joined {
// session as using it would end up in undecryptable // session as using it would end up in undecryptable
// messages. // messages.
if let Err(r) = response { if let Err(r) = response {
self.client self.client.base_client.invalidate_group_session(self.inner.room_id()).await?;
.base_client
.invalidate_group_session(self.inner.room_id())
.await?;
return Err(r); return Err(r);
} }
} }
@ -324,19 +312,13 @@ impl Joined {
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[instrument] #[instrument]
async fn share_group_session(&self) -> Result<()> { async fn share_group_session(&self) -> Result<()> {
let mut requests = self let mut requests =
.client self.client.base_client.share_group_session(self.inner.room_id()).await?;
.base_client
.share_group_session(self.inner.room_id())
.await?;
for request in requests.drain(..) { for request in requests.drain(..) {
let response = self.client.send_to_device(&request).await?; let response = self.client.send_to_device(&request).await?;
self.client self.client.base_client.mark_request_as_sent(&request.txn_id, &response).await?;
.base_client
.mark_request_as_sent(&request.txn_id, &response)
.await?;
} }
Ok(()) Ok(())
@ -403,10 +385,7 @@ impl Joined {
self.preshare_group_session().await?; self.preshare_group_session().await?;
AnyMessageEventContent::RoomEncrypted( AnyMessageEventContent::RoomEncrypted(
self.client self.client.base_client.encrypt(self.inner.room_id(), content).await?,
.base_client
.encrypt(self.inner.room_id(), content)
.await?,
) )
} else { } else {
content.into() content.into()
@ -426,8 +405,9 @@ impl Joined {
/// If the room is encrypted and the encryption feature is enabled the /// If the room is encrypted and the encryption feature is enabled the
/// upload will be encrypted. /// upload will be encrypted.
/// ///
/// This is a convenience method that calls the [`Client::upload()`](#Client::method.upload) /// This is a convenience method that calls the
/// and afterwards the [`send()`](#method.send). /// [`Client::upload()`](#Client::method.upload) and afterwards the
/// [`send()`](#method.send).
/// ///
/// # Arguments /// # Arguments
/// * `body` - A textual representation of the media that is going to be /// * `body` - A textual representation of the media that is going to be
@ -534,10 +514,7 @@ impl Joined {
}), }),
}; };
self.send( self.send(AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)), txn_id)
AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)),
txn_id,
)
.await .await
} }
@ -635,10 +612,10 @@ impl Joined {
txn_id: Option<Uuid>, txn_id: Option<Uuid>,
) -> Result<redact_event::Response> { ) -> Result<redact_event::Response> {
let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string(); let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string();
let request = assign!( let request =
redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), assign!(redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), {
{ reason } reason
); });
self.client.send(request, None).await self.client.send(request, None).await
} }

View File

@ -23,9 +23,7 @@ impl Left {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> { pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private // TODO: Make this private
if room.room_type() == RoomType::Left { if room.room_type() == RoomType::Left {
Some(Self { Some(Self { inner: Common::new(client, room) })
inner: Common::new(client, room),
})
} else { } else {
None None
} }

View File

@ -21,10 +21,7 @@ impl Deref for RoomMember {
impl RoomMember { impl RoomMember {
pub(crate) fn new(client: Client, member: BaseRoomMember) -> Self { pub(crate) fn new(client: Client, member: BaseRoomMember) -> Self {
Self { Self { inner: member, client }
inner: member,
client,
}
} }
/// Gets the avatar of this member, if set. /// Gets the avatar of this member, if set.

View File

@ -26,11 +26,7 @@ impl AppserviceEventHandler {
#[async_trait] #[async_trait]
impl EventHandler for AppserviceEventHandler { impl EventHandler for AppserviceEventHandler {
async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) { async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) {
if !self if !self.appservice.user_id_is_in_namespace(&event.state_key).unwrap() {
.appservice
.user_id_is_in_namespace(&event.state_key)
.unwrap()
{
dbg!("not an appservice user"); dbg!("not an appservice user");
return; return;
} }
@ -38,11 +34,7 @@ impl EventHandler for AppserviceEventHandler {
if let MembershipState::Invite = event.content.membership { if let MembershipState::Invite = event.content.membership {
let user_id = UserId::try_from(event.state_key.clone()).unwrap(); let user_id = UserId::try_from(event.state_key.clone()).unwrap();
let client = self let client = self.appservice.client_with_localpart(user_id.localpart()).await.unwrap();
.appservice
.client_with_localpart(user_id.localpart())
.await
.unwrap();
client.join_room_by_id(room.room_id()).await.unwrap(); client.join_room_by_id(room.room_id()).await.unwrap();
} }
@ -51,10 +43,7 @@ impl EventHandler for AppserviceEventHandler {
#[actix_web::main] #[actix_web::main]
pub async fn main() -> std::io::Result<()> { pub async fn main() -> std::io::Result<()> {
env::set_var( env::set_var("RUST_LOG", "actix_web=debug,actix_server=info,matrix_sdk=debug");
"RUST_LOG",
"actix_web=debug,actix_server=info,matrix_sdk=debug",
);
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let homeserver_url = "http://localhost:8008"; let homeserver_url = "http://localhost:8008";
@ -62,16 +51,11 @@ pub async fn main() -> std::io::Result<()> {
let registration = let registration =
AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap(); AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap();
let appservice = Appservice::new(homeserver_url, server_name, registration) let appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap();
.await
.unwrap();
let event_handler = AppserviceEventHandler::new(appservice.clone()); let event_handler = AppserviceEventHandler::new(appservice.clone());
appservice appservice.client().set_event_handler(Box::new(event_handler)).await;
.client()
.set_event_handler(Box::new(event_handler))
.await;
HttpServer::new(move || App::new().service(appservice.actix_service())) HttpServer::new(move || App::new().service(appservice.actix_service()))
.bind(("0.0.0.0", 8090))? .bind(("0.0.0.0", 8090))?

View File

@ -52,10 +52,7 @@ pub fn get_scope() -> Scope {
} }
fn gen_scope(scope: &str) -> Scope { fn gen_scope(scope: &str) -> Scope {
web::scope(scope) web::scope(scope).service(push_transactions).service(query_user_id).service(query_room_alias)
.service(push_transactions)
.service(query_user_id)
.service(query_room_alias)
} }
#[tracing::instrument] #[tracing::instrument]
@ -68,11 +65,7 @@ async fn push_transactions(
return Ok(HttpResponse::Unauthorized().finish()); return Ok(HttpResponse::Unauthorized().finish());
} }
appservice appservice.client().receive_transaction(request.incoming).await.unwrap();
.client()
.receive_transaction(request.incoming)
.await
.unwrap();
Ok(HttpResponse::Ok().json("{}")) Ok(HttpResponse::Ok().json("{}"))
} }
@ -135,13 +128,9 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
uri uri
}; };
let mut builder = http::request::Builder::new() let mut builder = http::request::Builder::new().method(request.method()).uri(uri);
.method(request.method())
.uri(uri);
let headers = builder let headers = builder.headers_mut().ok_or(Error::UnknownHttpRequestBuilder)?;
.headers_mut()
.ok_or(Error::UnknownHttpRequestBuilder)?;
for (key, value) in request.headers().iter() { for (key, value) in request.headers().iter() {
headers.append(key, value.to_owned()); headers.append(key, value.to_owned());
} }
@ -157,10 +146,7 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
let access_token = match request.uri().query() { let access_token = match request.uri().query() {
Some(query) => { Some(query) => {
let query: Vec<(String, String)> = matrix_sdk::urlencoded::from_str(query)?; let query: Vec<(String, String)> = matrix_sdk::urlencoded::from_str(query)?;
query query.into_iter().find(|(key, _)| key == "access_token").map(|(_, value)| value)
.into_iter()
.find(|(key, _)| key == "access_token")
.map(|(_, value)| value)
} }
None => None, None => None,
}; };

View File

@ -104,9 +104,7 @@ impl AppserviceRegistration {
/// ///
/// See the fields of [`Registration`] for the required format /// See the fields of [`Registration`] for the required format
pub fn try_from_yaml_str(value: impl AsRef<str>) -> Result<Self> { pub fn try_from_yaml_str(value: impl AsRef<str>) -> Result<Self> {
Ok(Self { Ok(Self { inner: serde_yaml::from_str(value.as_ref())? })
inner: serde_yaml::from_str(value.as_ref())?,
})
} }
/// Try to load registration from yaml file /// Try to load registration from yaml file
@ -115,9 +113,7 @@ impl AppserviceRegistration {
pub fn try_from_yaml_file(path: impl Into<PathBuf>) -> Result<Self> { pub fn try_from_yaml_file(path: impl Into<PathBuf>) -> Result<Self> {
let file = File::open(path.into())?; let file = File::open(path.into())?;
Ok(Self { Ok(Self { inner: serde_yaml::from_reader(file)? })
inner: serde_yaml::from_reader(file)?,
})
} }
} }

View File

@ -6,10 +6,7 @@ mod actix {
use matrix_sdk_appservice::*; use matrix_sdk_appservice::*;
async fn appservice() -> Appservice { async fn appservice() -> Appservice {
env::set_var( env::set_var("RUST_LOG", "mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug");
"RUST_LOG",
"mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug",
);
let _ = tracing_subscriber::fmt::try_init(); let _ = tracing_subscriber::fmt::try_init();
Appservice::new( Appservice::new(

View File

@ -76,10 +76,7 @@ async fn test_event_handler() -> Result<()> {
} }
} }
appservice appservice.client().set_event_handler(Box::new(Example::new())).await;
.client()
.set_event_handler(Box::new(Example::new()))
.await;
let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap(); let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap();
let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into(); let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into();

View File

@ -76,17 +76,8 @@ impl InspectorHelper {
fn complete_event_types(&self, arg: Option<&&str>) -> Vec<Pair> { fn complete_event_types(&self, arg: Option<&&str>) -> Vec<Pair> {
Self::EVENT_TYPES Self::EVENT_TYPES
.iter() .iter()
.map(|t| Pair { .map(|t| Pair { display: t.to_string(), replacement: format!("{} ", t) })
display: t.to_string(), .filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
replacement: format!("{} ", t),
})
.filter(|r| {
if let Some(arg) = arg {
r.replacement.starts_with(arg)
} else {
true
}
})
.collect() .collect()
} }
@ -99,13 +90,7 @@ impl InspectorHelper {
display: r.room_id.to_string(), display: r.room_id.to_string(),
replacement: format!("{} ", r.room_id.to_string()), replacement: format!("{} ", r.room_id.to_string()),
}) })
.filter(|r| { .filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
if let Some(arg) = arg {
r.replacement.starts_with(arg)
} else {
true
}
})
.collect() .collect()
} }
} }
@ -124,15 +109,9 @@ impl Completer for InspectorHelper {
let commands = vec![ let commands = vec![
("get-state", "get a state event in the given room"), ("get-state", "get a state event in the given room"),
( ("get-profiles", "get all the stored profiles in the given room"),
"get-profiles",
"get all the stored profiles in the given room",
),
("list-rooms", "list all rooms"), ("list-rooms", "list all rooms"),
( ("get-members", "get all the membership events in the given room"),
"get-members",
"get all the membership events in the given room",
),
] ]
.iter() .iter()
.map(|(r, d)| Pair { .map(|(r, d)| Pair {
@ -151,19 +130,13 @@ impl Completer for InspectorHelper {
} else { } else {
Ok(( Ok((
0, 0,
commands commands.into_iter().filter(|c| c.replacement.starts_with(args[0])).collect(),
.into_iter()
.filter(|c| c.replacement.starts_with(args[0]))
.collect(),
)) ))
} }
} else if args.len() == 2 { } else if args.len() == 2 {
if args[0] == "get-state" { if args[0] == "get-state" {
if line.ends_with(' ') { if line.ends_with(' ') {
Ok(( Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
args[0].len() + args[1].len() + 2,
self.complete_event_types(args.get(2)),
))
} else { } else {
Ok((args[0].len() + 1, self.complete_rooms(args.get(1)))) Ok((args[0].len() + 1, self.complete_rooms(args.get(1))))
} }
@ -174,10 +147,7 @@ impl Completer for InspectorHelper {
} }
} else if args.len() == 3 { } else if args.len() == 3 {
if args[0] == "get-state" { if args[0] == "get-state" {
Ok(( Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
args[0].len() + args[1].len() + 2,
self.complete_event_types(args.get(2)),
))
} else { } else {
Ok((pos, vec![])) Ok((pos, vec![]))
} }
@ -213,12 +183,7 @@ impl Printer {
let syntax_set: SyntaxSet = from_binary(include_bytes!("./syntaxes.bin")); let syntax_set: SyntaxSet = from_binary(include_bytes!("./syntaxes.bin"));
let themes: ThemeSet = from_binary(include_bytes!("./themes.bin")); let themes: ThemeSet = from_binary(include_bytes!("./themes.bin"));
Self { Self { ps: syntax_set.into(), ts: themes.into(), json, color }
ps: syntax_set.into(),
ts: themes.into(),
json,
color,
}
} }
fn pretty_print_struct<T: Debug + Serialize>(&self, data: &T) { fn pretty_print_struct<T: Debug + Serialize>(&self, data: &T) {
@ -229,13 +194,9 @@ impl Printer {
}; };
let syntax = if self.json { let syntax = if self.json {
self.ps self.ps.find_syntax_by_extension("rs").expect("Can't find rust syntax extension")
.find_syntax_by_extension("rs")
.expect("Can't find rust syntax extension")
} else { } else {
self.ps self.ps.find_syntax_by_extension("json").expect("Can't find json syntax extension")
.find_syntax_by_extension("json")
.expect("Can't find json syntax extension")
}; };
if self.color { if self.color {
@ -302,11 +263,7 @@ impl Inspector {
} }
async fn get_display_name_owners(&self, room_id: RoomId, display_name: String) { async fn get_display_name_owners(&self, room_id: RoomId, display_name: String) {
let users = self let users = self.store.get_users_with_display_name(&room_id, &display_name).await.unwrap();
.store
.get_users_with_display_name(&room_id, &display_name)
.await
.unwrap();
self.printer.pretty_print_struct(&users); self.printer.pretty_print_struct(&users);
} }
@ -323,22 +280,14 @@ impl Inspector {
let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap(); let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
for member in joined { for member in joined {
let event = self let event = self.store.get_member_event(&room_id, &member).await.unwrap();
.store
.get_member_event(&room_id, &member)
.await
.unwrap();
self.printer.pretty_print_struct(&event); self.printer.pretty_print_struct(&event);
} }
} }
async fn get_state(&self, room_id: RoomId, event_type: EventType) { async fn get_state(&self, room_id: RoomId, event_type: EventType) {
self.printer.pretty_print_struct( self.printer.pretty_print_struct(
&self &self.store.get_state_event(&room_id, event_type, "").await.unwrap(),
.store
.get_state_event(&room_id, event_type, "")
.await
.unwrap(),
); );
} }
@ -347,35 +296,25 @@ impl Inspector {
SubCommand::with_name("list-rooms"), SubCommand::with_name("list-rooms"),
SubCommand::with_name("get-members").arg( SubCommand::with_name("get-members").arg(
Arg::with_name("room-id").required(true).validator(|r| { Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r) RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
}), }),
), ),
SubCommand::with_name("get-profiles").arg( SubCommand::with_name("get-profiles").arg(
Arg::with_name("room-id").required(true).validator(|r| { Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r) RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
}), }),
), ),
SubCommand::with_name("get-display-names") SubCommand::with_name("get-display-names")
.arg(Arg::with_name("room-id").required(true).validator(|r| { .arg(Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r) RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
})) }))
.arg(Arg::with_name("display-name").required(true)), .arg(Arg::with_name("display-name").required(true)),
SubCommand::with_name("get-state") SubCommand::with_name("get-state")
.arg(Arg::with_name("room-id").required(true).validator(|r| { .arg(Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r) RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
})) }))
.arg(Arg::with_name("event-type").required(true).validator(|e| { .arg(Arg::with_name("event-type").required(true).validator(|e| {
EventType::try_from(e) EventType::try_from(e).map(|_| ()).map_err(|_| "Invalid event type".to_string())
.map(|_| ())
.map_err(|_| "Invalid event type".to_string())
})), })),
] ]
} }

View File

@ -87,20 +87,21 @@ pub struct AdditionalUnsignedData {
pub prev_content: Option<Raw<MemberEventContent>>, pub prev_content: Option<Raw<MemberEventContent>>,
} }
/// Transform state event by hoisting `prev_content` field from `unsigned` to the top level. /// Transform state event by hoisting `prev_content` field from `unsigned` to
/// the top level.
/// ///
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in `unsigned` contrary to /// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
/// the C2S spec. Some more discussion can be found [here][discussion]. Until this is fixed in /// `unsigned` contrary to the C2S spec. Some more discussion can be found
/// synapse or handled in Ruma, we use this to hoist up `prev_content` to the top level. /// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we
/// use this to hoist up `prev_content` to the top level.
/// ///
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668> /// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668> /// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
pub fn hoist_and_deserialize_state_event( pub fn hoist_and_deserialize_state_event(
event: &Raw<AnySyncStateEvent>, event: &Raw<AnySyncStateEvent>,
) -> StdResult<AnySyncStateEvent, serde_json::Error> { ) -> StdResult<AnySyncStateEvent, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())? let prev_content =
.unsigned serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
.prev_content;
let mut ev = event.deserialize()?; let mut ev = event.deserialize()?;
@ -116,9 +117,8 @@ pub fn hoist_and_deserialize_state_event(
fn hoist_member_event( fn hoist_member_event(
event: &Raw<StateEvent<MemberEventContent>>, event: &Raw<StateEvent<MemberEventContent>>,
) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> { ) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())? let prev_content =
.unsigned serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
.prev_content;
let mut e = event.deserialize()?; let mut e = event.deserialize()?;
@ -340,7 +340,8 @@ impl BaseClient {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `response` - A successful login response that contains our access token /// * `response` - A successful login response that contains our access
/// token
/// and device id. /// and device id.
pub async fn receive_login_response( pub async fn receive_login_response(
&self, &self,
@ -440,9 +441,7 @@ impl BaseClient {
AnySyncRoomEvent::State(s) => match s { AnySyncRoomEvent::State(s) => match s {
AnySyncStateEvent::RoomMember(member) => { AnySyncStateEvent::RoomMember(member) => {
if let Ok(member) = MemberEvent::try_from(member.clone()) { if let Ok(member) = MemberEvent::try_from(member.clone()) {
ambiguity_cache ambiguity_cache.handle_event(changes, room_id, &member).await?;
.handle_event(changes, room_id, &member)
.await?;
match member.content.membership { match member.content.membership {
MembershipState::Join | MembershipState::Invite => { MembershipState::Join | MembershipState::Invite => {
@ -500,8 +499,7 @@ impl BaseClient {
} }
if let Some(context) = &mut push_context { if let Some(context) = &mut push_context {
self.update_push_room_context(context, user_id, room_info, changes) self.update_push_room_context(context, user_id, room_info, changes).await;
.await;
} else { } else {
push_context = self.get_push_room_context(room, room_info, changes).await?; push_context = self.get_push_room_context(room, room_info, changes).await?;
} }
@ -521,10 +519,13 @@ impl BaseClient {
), ),
); );
} }
// TODO if there is an Action::SetTweak(Tweak::Highlight) we need to store // TODO if there is an
// its value with the event so a client can show if the event is highlighted // Action::SetTweak(Tweak::Highlight) we need to store
// its value with the event so a client can show if the
// event is highlighted
// in the UI. // in the UI.
// Requires the possibility to associate custom data with events and to // Requires the possibility to associate custom data
// with events and to
// store them. // store them.
} }
} }
@ -762,18 +763,14 @@ impl BaseClient {
let mut changes = StateChanges::new(next_batch.clone()); let mut changes = StateChanges::new(next_batch.clone());
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone()); let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
self.handle_account_data(&account_data.events, &mut changes) self.handle_account_data(&account_data.events, &mut changes).await;
.await;
let push_rules = self.get_push_rules(&changes).await?; let push_rules = self.get_push_rules(&changes).await?;
let mut new_rooms = Rooms::default(); let mut new_rooms = Rooms::default();
for (room_id, new_info) in rooms.join { for (room_id, new_info) in rooms.join {
let room = self let room = self.store.get_or_create_room(&room_id, RoomType::Joined).await;
.store
.get_or_create_room(&room_id, RoomType::Joined)
.await;
let mut room_info = room.clone_info(); let mut room_info = room.clone_info();
room_info.mark_as_joined(); room_info.mark_as_joined();
@ -844,10 +841,7 @@ impl BaseClient {
} }
for (room_id, new_info) in rooms.leave { for (room_id, new_info) in rooms.leave {
let room = self let room = self.store.get_or_create_room(&room_id, RoomType::Left).await;
.store
.get_or_create_room(&room_id, RoomType::Left)
.await;
let mut room_info = room.clone_info(); let mut room_info = room.clone_info();
room_info.mark_as_left(); room_info.mark_as_left();
@ -876,18 +870,14 @@ impl BaseClient {
.await; .await;
changes.add_room(room_info); changes.add_room(room_info);
new_rooms.leave.insert( new_rooms
room_id, .leave
LeftRoom::new(timeline, new_info.state, new_info.account_data), .insert(room_id, LeftRoom::new(timeline, new_info.state, new_info.account_data));
);
} }
for (room_id, new_info) in rooms.invite { for (room_id, new_info) in rooms.invite {
{ {
let room = self let room = self.store.get_or_create_room(&room_id, RoomType::Invited).await;
.store
.get_or_create_room(&room_id, RoomType::Invited)
.await;
let mut room_info = room.clone_info(); let mut room_info = room.clone_info();
room_info.mark_as_invited(); room_info.mark_as_invited();
changes.add_room(room_info); changes.add_room(room_info);
@ -934,9 +924,7 @@ impl BaseClient {
.into_iter() .into_iter()
.map(|(k, v)| (k, v.into())) .map(|(k, v)| (k, v.into()))
.collect(), .collect(),
ambiguity_changes: AmbiguityChanges { ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
changes: ambiguity_cache.changes,
},
notifications: changes.notifications, notifications: changes.notifications,
}; };
@ -968,11 +956,7 @@ impl BaseClient {
let members: Vec<MemberEvent> = response let members: Vec<MemberEvent> = response
.chunk .chunk
.iter() .iter()
.filter_map(|e| { .filter_map(|e| hoist_member_event(e).ok().and_then(|e| MemberEvent::try_from(e).ok()))
hoist_member_event(e)
.ok()
.and_then(|e| MemberEvent::try_from(e).ok())
})
.collect(); .collect();
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone()); let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
@ -986,12 +970,7 @@ impl BaseClient {
let mut user_ids = BTreeSet::new(); let mut user_ids = BTreeSet::new();
for member in &members { for member in &members {
if self if self.store.get_member_event(&room_id, &member.state_key).await?.is_none() {
.store
.get_member_event(&room_id, &member.state_key)
.await?
.is_none()
{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
match member.content.membership { match member.content.membership {
MembershipState::Join | MembershipState::Invite => { MembershipState::Join | MembershipState::Invite => {
@ -1000,9 +979,7 @@ impl BaseClient {
_ => (), _ => (),
} }
ambiguity_cache ambiguity_cache.handle_event(&changes, room_id, &member).await?;
.handle_event(&changes, room_id, &member)
.await?;
if member.state_key == member.sender { if member.state_key == member.sender {
changes changes
@ -1036,9 +1013,7 @@ impl BaseClient {
Ok(MembersResponse { Ok(MembersResponse {
chunk: members, chunk: members,
ambiguity_changes: AmbiguityChanges { ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
changes: ambiguity_cache.changes,
},
}) })
} }
@ -1050,7 +1025,8 @@ impl BaseClient {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `filter_name` - The name that should be used to persist the filter id in /// * `filter_name` - The name that should be used to persist the filter id
/// in
/// the store. /// the store.
/// ///
/// * `response` - The successful filter upload response containing the /// * `response` - The successful filter upload response containing the
@ -1062,10 +1038,7 @@ impl BaseClient {
filter_name: &str, filter_name: &str,
response: &api::filter::create_filter::Response, response: &api::filter::create_filter::Response,
) -> Result<()> { ) -> Result<()> {
Ok(self Ok(self.store.save_filter(filter_name, &response.filter_id).await?)
.store
.save_filter(filter_name, &response.filter_id)
.await?)
} }
/// Get the filter id of a previously uploaded filter. /// Get the filter id of a previously uploaded filter.
@ -1223,18 +1196,15 @@ impl BaseClient {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `flow_id` - The unique id that identifies a interactive verification flow. For in-room /// * `flow_id` - The unique id that identifies a interactive verification
/// verifications this will be the event id of the *m.key.verification.request* event that /// flow. For in-room verifications this will be the event id of the
/// started the flow, for the to-device verification flows this will be the transaction id of /// *m.key.verification.request* event that started the flow, for the
/// the *m.key.verification.start* event. /// to-device verification flows this will be the transaction id of the
/// *m.key.verification.start* event.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> { pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
self.olm self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
.lock()
.await
.as_ref()
.and_then(|o| o.get_verification(flow_id))
} }
/// Get a specific device of a user. /// Get a specific device of a user.
@ -1283,10 +1253,12 @@ impl BaseClient {
/// Get the user login session. /// Get the user login session.
/// ///
/// If the client is currently logged in, this will return a `matrix_sdk::Session` object which /// If the client is currently logged in, this will return a
/// can later be given to `restore_login`. /// `matrix_sdk::Session` object which can later be given to
/// `restore_login`.
/// ///
/// Returns a session object if the client is logged in. Otherwise returns `None`. /// Returns a session object if the client is logged in. Otherwise returns
/// `None`.
pub async fn get_session(&self) -> Option<Session> { pub async fn get_session(&self) -> Option<Session> {
self.session.read().await.clone() self.session.read().await.clone()
} }
@ -1348,8 +1320,9 @@ impl BaseClient {
/// Get the push rules. /// Get the push rules.
/// ///
/// Gets the push rules from `changes` if they have been updated, otherwise get them from the /// Gets the push rules from `changes` if they have been updated, otherwise
/// store. As a fallback, uses `Ruleset::server_default` if the user is logged in. /// get them from the store. As a fallback, uses
/// `Ruleset::server_default` if the user is logged in.
pub async fn get_push_rules(&self, changes: &StateChanges) -> Result<Ruleset> { pub async fn get_push_rules(&self, changes: &StateChanges) -> Result<Ruleset> {
if let Some(AnyGlobalAccountDataEvent::PushRules(event)) = changes if let Some(AnyGlobalAccountDataEvent::PushRules(event)) = changes
.account_data .account_data
@ -1373,11 +1346,11 @@ impl BaseClient {
/// Get the push context for the given room. /// Get the push context for the given room.
/// ///
/// Tries to get the data from `changes` or the up to date `room_info`. Loads the data from the /// Tries to get the data from `changes` or the up to date `room_info`.
/// store otherwise. /// Loads the data from the store otherwise.
/// ///
/// Returns `None` if some data couldn't be found. This should only happen in brand new rooms, /// Returns `None` if some data couldn't be found. This should only happen
/// while we process its state. /// in brand new rooms, while we process its state.
pub async fn get_push_room_context( pub async fn get_push_room_context(
&self, &self,
room: &Room, room: &Room,
@ -1389,16 +1362,10 @@ impl BaseClient {
let member_count = room_info.active_members_count(); let member_count = room_info.active_members_count();
let user_display_name = if let Some(member) = changes let user_display_name = if let Some(member) =
.members changes.members.get(room_id).and_then(|members| members.get(user_id))
.get(room_id)
.and_then(|members| members.get(user_id))
{ {
member member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
.content
.displayname
.clone()
.unwrap_or_else(|| user_id.localpart().to_owned())
} else if let Some(member) = room.get_member(user_id).await? { } else if let Some(member) = room.get_member(user_id).await? {
member.name().to_owned() member.name().to_owned()
} else { } else {
@ -1448,16 +1415,10 @@ impl BaseClient {
push_rules.member_count = UInt::new(room_info.active_members_count()).unwrap_or(UInt::MAX); push_rules.member_count = UInt::new(room_info.active_members_count()).unwrap_or(UInt::MAX);
if let Some(member) = changes if let Some(member) = changes.members.get(room_id).and_then(|members| members.get(user_id))
.members
.get(room_id)
.and_then(|members| members.get(user_id))
{ {
push_rules.user_display_name = member push_rules.user_display_name =
.content member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
.displayname
.clone()
.unwrap_or_else(|| user_id.localpart().to_owned())
} }
if let Some(AnySyncStateEvent::RoomPowerLevels(event)) = changes if let Some(AnySyncStateEvent::RoomPowerLevels(event)) = changes

View File

@ -28,7 +28,8 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Internal representation of errors. /// Internal representation of errors.
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
/// Queried endpoint requires authentication but was called on an anonymous client. /// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")] #[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired, AuthenticationRequired,

View File

@ -66,20 +66,12 @@ impl BaseRoomInfo {
let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1); let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1);
if heroes_count >= invited_joined { if heroes_count >= invited_joined {
let mut names = heroes let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
.iter()
.take(3)
.map(|mem| mem.name())
.collect::<Vec<&str>>();
// stabilize ordering // stabilize ordering
names.sort_unstable(); names.sort_unstable();
names.join(", ") names.join(", ")
} else if heroes_count < invited_joined && invited_joined > 1 { } else if heroes_count < invited_joined && invited_joined > 1 {
let mut names = heroes let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
.iter()
.take(3)
.map(|mem| mem.name())
.collect::<Vec<&str>>();
names.sort_unstable(); names.sort_unstable();
// TODO: What length does the spec want us to use here and in // TODO: What length does the spec want us to use here and in
@ -144,10 +136,8 @@ impl BaseRoomInfo {
true true
} }
AnyStateEventContent::RoomPowerLevels(p) => { AnyStateEventContent::RoomPowerLevels(p) => {
let max_power_level = p let max_power_level =
.users p.users.values().fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
.values()
.fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
self.max_power_level = max_power_level; self.max_power_level = max_power_level;
true true
} }

View File

@ -43,7 +43,8 @@ use crate::{
store::{Result as StoreResult, StateStore}, store::{Result as StoreResult, StateStore},
}; };
/// The underlying room data structure collecting state for joined, left and invtied rooms. /// The underlying room data structure collecting state for joined, left and
/// invtied rooms.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Room { pub struct Room {
room_id: Arc<RoomId>, room_id: Arc<RoomId>,
@ -134,7 +135,8 @@ impl Room {
/// Check if the room has it's members fully synced. /// Check if the room has it's members fully synced.
/// ///
/// Members might be missing if lazy member loading was enabled for the sync. /// Members might be missing if lazy member loading was enabled for the
/// sync.
/// ///
/// Returns true if no members are missing, false otherwise. /// Returns true if no members are missing, false otherwise.
pub fn are_members_synced(&self) -> bool { pub fn are_members_synced(&self) -> bool {
@ -199,12 +201,7 @@ impl Room {
/// Get the history visibility policy of this room. /// Get the history visibility policy of this room.
pub fn history_visibility(&self) -> HistoryVisibility { pub fn history_visibility(&self) -> HistoryVisibility {
self.inner self.inner.read().unwrap().base_info.history_visibility.clone()
.read()
.unwrap()
.base_info
.history_visibility
.clone()
} }
/// Is the room considered to be public. /// Is the room considered to be public.
@ -366,9 +363,7 @@ impl Room {
); );
let inner = self.inner.read().unwrap(); let inner = self.inner.read().unwrap();
Ok(inner Ok(inner.base_info.calculate_room_name(joined, invited, members))
.base_info
.calculate_room_name(joined, invited, members))
} }
pub(crate) fn clone_info(&self) -> RoomInfo { pub(crate) fn clone_info(&self) -> RoomInfo {
@ -393,11 +388,8 @@ impl Room {
return Ok(None); return Ok(None);
}; };
let presence = self let presence =
.store self.store.get_presence_event(user_id).await?.and_then(|e| e.deserialize().ok());
.get_presence_event(user_id)
.await?
.and_then(|e| e.deserialize().ok());
let profile = self.store.get_profile(self.room_id(), user_id).await?; let profile = self.store.get_profile(self.room_id(), user_id).await?;
let max_power_level = self.max_power_level(); let max_power_level = self.max_power_level();
let is_room_creator = self let is_room_creator = self
@ -410,8 +402,8 @@ impl Room {
.map(|c| &c.creator == user_id) .map(|c| &c.creator == user_id)
.unwrap_or(false); .unwrap_or(false);
let power = self let power =
.store self.store
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "") .get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
.await? .await?
.and_then(|e| e.deserialize().ok()) .and_then(|e| e.deserialize().ok())
@ -427,11 +419,7 @@ impl Room {
.store .store
.get_users_with_display_name( .get_users_with_display_name(
self.room_id(), self.room_id(),
member_event member_event.content.displayname.as_deref().unwrap_or_else(|| user_id.localpart()),
.content
.displayname
.as_deref()
.unwrap_or_else(|| user_id.localpart()),
) )
.await? .await?
.len() .len()
@ -557,8 +545,6 @@ impl RoomInfo {
/// ///
/// The return value is saturated at `u64::MAX`. /// The return value is saturated at `u64::MAX`.
pub fn active_members_count(&self) -> u64 { pub fn active_members_count(&self) -> u64 {
self.summary self.summary.joined_member_count.saturating_add(self.summary.invited_member_count)
.joined_member_count
.saturating_add(self.summary.invited_member_count)
} }
} }

View File

@ -49,11 +49,8 @@ impl AmbiguityMap {
} }
fn add(&mut self, user_id: UserId) -> Option<UserId> { fn add(&mut self, user_id: UserId) -> Option<UserId> {
let ambiguous_user = if self.user_count() == 1 { let ambiguous_user =
self.users.iter().next().cloned() if self.user_count() == 1 { self.users.iter().next().cloned() } else { None };
} else {
None
};
self.users.insert(user_id); self.users.insert(user_id);
@ -71,11 +68,7 @@ impl AmbiguityMap {
impl AmbiguityCache { impl AmbiguityCache {
pub fn new(store: Store) -> Self { pub fn new(store: Store) -> Self {
Self { Self { store, cache: BTreeMap::new(), changes: BTreeMap::new() }
store,
cache: BTreeMap::new(),
changes: BTreeMap::new(),
}
} }
pub async fn handle_event( pub async fn handle_event(
@ -113,12 +106,9 @@ impl AmbiguityCache {
return Ok(()); return Ok(());
} }
let disambiguated_member = old_map let disambiguated_member = old_map.as_mut().and_then(|o| o.remove(&member_event.state_key));
.as_mut() let ambiguated_member =
.and_then(|o| o.remove(&member_event.state_key)); new_map.as_mut().and_then(|n| n.add(member_event.state_key.clone()));
let ambiguated_member = new_map
.as_mut()
.and_then(|n| n.add(member_event.state_key.clone()));
let ambiguous = new_map.as_ref().map(|n| n.is_ambiguous()).unwrap_or(false); let ambiguous = new_map.as_ref().map(|n| n.is_ambiguous()).unwrap_or(false);
self.update(room_id, old_map, new_map); self.update(room_id, old_map, new_map);
@ -129,11 +119,7 @@ impl AmbiguityCache {
member_ambiguous: ambiguous, member_ambiguous: ambiguous,
}; };
trace!( trace!("Handling display name ambiguity for {}: {:#?}", member_event.state_key, change);
"Handling display name ambiguity for {}: {:#?}",
member_event.state_key,
change
);
self.add_change(room_id, member_event.event_id.clone(), change); self.add_change(room_id, member_event.event_id.clone(), change);
@ -146,10 +132,7 @@ impl AmbiguityCache {
old_map: Option<AmbiguityMap>, old_map: Option<AmbiguityMap>,
new_map: Option<AmbiguityMap>, new_map: Option<AmbiguityMap>,
) { ) {
let entry = self let entry = self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new);
.cache
.entry(room_id.clone())
.or_insert_with(BTreeMap::new);
if let Some(old) = old_map { if let Some(old) = old_map {
entry.insert(old.display_name, old.users); entry.insert(old.display_name, old.users);
@ -161,10 +144,7 @@ impl AmbiguityCache {
} }
fn add_change(&mut self, room_id: &RoomId, event_id: EventId, change: AmbiguityChange) { fn add_change(&mut self, room_id: &RoomId, event_id: EventId, change: AmbiguityChange) {
self.changes self.changes.entry(room_id.clone()).or_insert_with(BTreeMap::new).insert(event_id, change);
.entry(room_id.clone())
.or_insert_with(BTreeMap::new)
.insert(event_id, change);
} }
async fn get( async fn get(
@ -175,16 +155,12 @@ impl AmbiguityCache {
) -> Result<(Option<AmbiguityMap>, Option<AmbiguityMap>)> { ) -> Result<(Option<AmbiguityMap>, Option<AmbiguityMap>)> {
use MembershipState::*; use MembershipState::*;
let old_event = if let Some(m) = changes let old_event = if let Some(m) =
.members changes.members.get(room_id).and_then(|m| m.get(&member_event.state_key))
.get(room_id)
.and_then(|m| m.get(&member_event.state_key))
{ {
Some(m.clone()) Some(m.clone())
} else { } else {
self.store self.store.get_member_event(room_id, &member_event.state_key).await?
.get_member_event(room_id, &member_event.state_key)
.await?
}; };
let old_display_name = if let Some(event) = old_event { let old_display_name = if let Some(event) = old_event {
@ -216,23 +192,15 @@ impl AmbiguityCache {
}; };
let old_map = if let Some(old_name) = old_display_name.as_deref() { let old_map = if let Some(old_name) = old_display_name.as_deref() {
let old_display_name_map = if let Some(u) = self let old_display_name_map = if let Some(u) =
.cache self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new).get(old_name)
.entry(room_id.clone())
.or_insert_with(BTreeMap::new)
.get(old_name)
{ {
u.clone() u.clone()
} else { } else {
self.store self.store.get_users_with_display_name(&room_id, &old_name).await?
.get_users_with_display_name(&room_id, &old_name)
.await?
}; };
Some(AmbiguityMap { Some(AmbiguityMap { display_name: old_name.to_string(), users: old_display_name_map })
display_name: old_name.to_string(),
users: old_display_name_map,
})
} else { } else {
None None
}; };
@ -244,8 +212,9 @@ impl AmbiguityCache {
.as_deref() .as_deref()
.unwrap_or_else(|| member_event.state_key.localpart()); .unwrap_or_else(|| member_event.state_key.localpart());
// We don't allow other users to set the display name, so if we have // We don't allow other users to set the display name, so if we
// a more trusted version of the display name use that. // have a more trusted version of the display
// name use that.
let new_display_name = if member_event.sender.as_str() == member_event.state_key { let new_display_name = if member_event.sender.as_str() == member_event.state_key {
new new
} else if let Some(old) = old_display_name.as_deref() { } else if let Some(old) = old_display_name.as_deref() {
@ -262,9 +231,7 @@ impl AmbiguityCache {
{ {
u.clone() u.clone()
} else { } else {
self.store self.store.get_users_with_display_name(&room_id, &new_display_name).await?
.get_users_with_display_name(&room_id, &new_display_name)
.await?
}; };
Some(AmbiguityMap { Some(AmbiguityMap {

View File

@ -80,8 +80,7 @@ impl MemoryStore {
} }
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.filters self.filters.insert(filter_name.to_string(), filter_id.to_string());
.insert(filter_name.to_string(), filter_id.to_string());
Ok(()) Ok(())
} }
@ -162,8 +161,7 @@ impl MemoryStore {
} }
for (event_type, event) in &changes.account_data { for (event_type, event) in &changes.account_data {
self.account_data self.account_data.insert(event_type.to_string(), event.clone());
.insert(event_type.to_string(), event.clone());
} }
for (room, events) in &changes.room_account_data { for (room, events) in &changes.room_account_data {
@ -197,8 +195,7 @@ impl MemoryStore {
} }
for (room_id, info) in &changes.invited_room_info { for (room_id, info) in &changes.invited_room_info {
self.stripped_room_info self.stripped_room_info.insert(room_id.clone(), info.clone());
.insert(room_id.clone(), info.clone());
} }
for (room, events) in &changes.stripped_members { for (room, events) in &changes.stripped_members {
@ -241,8 +238,7 @@ impl MemoryStore {
) -> Result<Option<Raw<AnySyncStateEvent>>> { ) -> Result<Option<Raw<AnySyncStateEvent>>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self.room_state.get(room_id).and_then(|e| { Ok(self.room_state.get(room_id).and_then(|e| {
e.get(event_type.as_ref()) e.get(event_type.as_ref()).and_then(|s| s.get(state_key).map(|e| e.clone()))
.and_then(|s| s.get(state_key).map(|e| e.clone()))
})) }))
} }
@ -252,10 +248,7 @@ impl MemoryStore {
user_id: &UserId, user_id: &UserId,
) -> Result<Option<MemberEventContent>> { ) -> Result<Option<MemberEventContent>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self Ok(self.profiles.get(room_id).and_then(|p| p.get(user_id).map(|p| p.clone())))
.profiles
.get(room_id)
.and_then(|p| p.get(user_id).map(|p| p.clone())))
} }
async fn get_member_event( async fn get_member_event(
@ -264,10 +257,7 @@ impl MemoryStore {
state_key: &UserId, state_key: &UserId,
) -> Result<Option<MemberEvent>> { ) -> Result<Option<MemberEvent>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self Ok(self.members.get(room_id).and_then(|m| m.get(state_key).map(|m| m.clone())))
.members
.get(room_id)
.and_then(|m| m.get(state_key).map(|m| m.clone())))
} }
fn get_user_ids(&self, room_id: &RoomId) -> Vec<UserId> { fn get_user_ids(&self, room_id: &RoomId) -> Vec<UserId> {
@ -308,10 +298,7 @@ impl MemoryStore {
&self, &self,
event_type: EventType, event_type: EventType,
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> { ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
Ok(self Ok(self.account_data.get(event_type.as_ref()).map(|e| e.clone()))
.account_data
.get(event_type.as_ref())
.map(|e| e.clone()))
} }
async fn get_room_account_data_event( async fn get_room_account_data_event(

View File

@ -200,7 +200,8 @@ pub trait StateStore: AsyncTraitDeps {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `room_id` - The id of the room for which the room account data event should /// * `room_id` - The id of the room for which the room account data event
/// should
/// be fetched. /// be fetched.
/// ///
/// * `event_type` - The event type of the room account data event. /// * `event_type` - The event type of the room account data event.
@ -297,16 +298,12 @@ impl Store {
/// Get all the rooms this store knows about. /// Get all the rooms this store knows about.
pub fn get_rooms(&self) -> Vec<Room> { pub fn get_rooms(&self) -> Vec<Room> {
self.rooms self.rooms.iter().filter_map(|r| self.get_room(r.key())).collect()
.iter()
.filter_map(|r| self.get_room(r.key()))
.collect()
} }
/// Get the room with the given room id. /// Get the room with the given room id.
pub fn get_room(&self, room_id: &RoomId) -> Option<Room> { pub fn get_room(&self, room_id: &RoomId) -> Option<Room> {
self.get_bare_room(room_id) self.get_bare_room(room_id).and_then(|r| match r.room_type() {
.and_then(|r| match r.room_type() {
RoomType::Joined => Some(r), RoomType::Joined => Some(r),
RoomType::Left => Some(r), RoomType::Left => Some(r),
RoomType::Invited => self.get_stripped_room(room_id), RoomType::Invited => self.get_stripped_room(room_id),
@ -320,10 +317,7 @@ impl Store {
pub(crate) async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room { pub(crate) async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room {
let session = self.session.read().await; let session = self.session.read().await;
let user_id = &session let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
.as_ref()
.expect("Creating room while not being logged in")
.user_id;
self.stripped_rooms self.stripped_rooms
.entry(room_id.clone()) .entry(room_id.clone())
@ -333,10 +327,7 @@ impl Store {
pub(crate) async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room { pub(crate) async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room {
let session = self.session.read().await; let session = self.session.read().await;
let user_id = &session let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
.as_ref()
.expect("Creating room while not being logged in")
.user_id;
self.rooms self.rooms
.entry(room_id.clone()) .entry(room_id.clone())
@ -358,8 +349,8 @@ impl Deref for Store {
pub struct StateChanges { pub struct StateChanges {
/// The sync token that relates to this update. /// The sync token that relates to this update.
pub sync_token: Option<String>, pub sync_token: Option<String>,
/// A user session, containing an access token and information about the associated user /// A user session, containing an access token and information about the
/// account. /// associated user account.
pub session: Option<Session>, pub session: Option<Session>,
/// A mapping of event type string to `AnyBasicEvent`. /// A mapping of event type string to `AnyBasicEvent`.
pub account_data: BTreeMap<String, Raw<AnyGlobalAccountDataEvent>>, pub account_data: BTreeMap<String, Raw<AnyGlobalAccountDataEvent>>,
@ -371,7 +362,8 @@ pub struct StateChanges {
/// A mapping of `RoomId` to a map of users and their `MemberEventContent`. /// A mapping of `RoomId` to a map of users and their `MemberEventContent`.
pub profiles: BTreeMap<RoomId, BTreeMap<UserId, MemberEventContent>>, pub profiles: BTreeMap<RoomId, BTreeMap<UserId, MemberEventContent>>,
/// A mapping of `RoomId` to a map of event type string to a state key and `AnySyncStateEvent`. /// A mapping of `RoomId` to a map of event type string to a state key and
/// `AnySyncStateEvent`.
pub state: BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnySyncStateEvent>>>>, pub state: BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnySyncStateEvent>>>>,
/// A mapping of `RoomId` to a map of event type string to `AnyBasicEvent`. /// A mapping of `RoomId` to a map of event type string to `AnyBasicEvent`.
pub room_account_data: BTreeMap<RoomId, BTreeMap<String, Raw<AnyRoomAccountDataEvent>>>, pub room_account_data: BTreeMap<RoomId, BTreeMap<String, Raw<AnyRoomAccountDataEvent>>>,
@ -397,10 +389,7 @@ pub struct StateChanges {
impl StateChanges { impl StateChanges {
/// Create a new `StateChanges` struct with the given sync_token. /// Create a new `StateChanges` struct with the given sync_token.
pub fn new(sync_token: String) -> Self { pub fn new(sync_token: String) -> Self {
Self { Self { sync_token: Some(sync_token), ..Default::default() }
sync_token: Some(sync_token),
..Default::default()
}
} }
/// Update the `StateChanges` struct with the given `PresenceEvent`. /// Update the `StateChanges` struct with the given `PresenceEvent`.
@ -410,14 +399,12 @@ impl StateChanges {
/// Update the `StateChanges` struct with the given `RoomInfo`. /// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_room(&mut self, room: RoomInfo) { pub fn add_room(&mut self, room: RoomInfo) {
self.room_infos self.room_infos.insert(room.room_id.as_ref().to_owned(), room);
.insert(room.room_id.as_ref().to_owned(), room);
} }
/// Update the `StateChanges` struct with the given `RoomInfo`. /// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_stripped_room(&mut self, room: RoomInfo) { pub fn add_stripped_room(&mut self, room: RoomInfo) {
self.invited_room_info self.invited_room_info.insert(room.room_id.as_ref().to_owned(), room);
.insert(room.room_id.as_ref().to_owned(), room);
} }
/// Update the `StateChanges` struct with the given `AnyBasicEvent`. /// Update the `StateChanges` struct with the given `AnyBasicEvent`.
@ -426,11 +413,11 @@ impl StateChanges {
event: AnyGlobalAccountDataEvent, event: AnyGlobalAccountDataEvent,
raw_event: Raw<AnyGlobalAccountDataEvent>, raw_event: Raw<AnyGlobalAccountDataEvent>,
) { ) {
self.account_data self.account_data.insert(event.content().event_type().to_owned(), raw_event);
.insert(event.content().event_type().to_owned(), raw_event);
} }
/// Update the `StateChanges` struct with the given room with a new `AnyBasicEvent`. /// Update the `StateChanges` struct with the given room with a new
/// `AnyBasicEvent`.
pub fn add_room_account_data( pub fn add_room_account_data(
&mut self, &mut self,
room_id: &RoomId, room_id: &RoomId,
@ -443,7 +430,8 @@ impl StateChanges {
.insert(event.content().event_type().to_owned(), raw_event); .insert(event.content().event_type().to_owned(), raw_event);
} }
/// Update the `StateChanges` struct with the given room with a new `StrippedMemberEvent`. /// Update the `StateChanges` struct with the given room with a new
/// `StrippedMemberEvent`.
pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) { pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) {
let user_id = event.state_key.clone(); let user_id = event.state_key.clone();
@ -453,7 +441,8 @@ impl StateChanges {
.insert(user_id, event); .insert(user_id, event);
} }
/// Update the `StateChanges` struct with the given room with a new `AnySyncStateEvent`. /// Update the `StateChanges` struct with the given room with a new
/// `AnySyncStateEvent`.
pub fn add_state_event( pub fn add_state_event(
&mut self, &mut self,
room_id: &RoomId, room_id: &RoomId,
@ -468,11 +457,9 @@ impl StateChanges {
.insert(event.state_key().to_string(), raw_event); .insert(event.state_key().to_string(), raw_event);
} }
/// Update the `StateChanges` struct with the given room with a new `Notification`. /// Update the `StateChanges` struct with the given room with a new
/// `Notification`.
pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) { pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) {
self.notifications self.notifications.entry(room_id.to_owned()).or_insert_with(Vec::new).push(notification);
.entry(room_id.to_owned())
.or_insert_with(Vec::new)
.push(notification);
} }
} }

View File

@ -108,13 +108,7 @@ impl EncodeKey for &str {
impl EncodeKey for (&str, &str) { impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[ [self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
self.0.as_bytes(),
&[Self::SEPARATOR],
self.1.as_bytes(),
&[Self::SEPARATOR],
]
.concat()
} }
} }
@ -164,9 +158,7 @@ impl std::fmt::Debug for SledStore {
if let Some(path) = &self.path { if let Some(path) = &self.path {
f.debug_struct("SledStore").field("path", &path).finish() f.debug_struct("SledStore").field("path", &path).finish()
} else { } else {
f.debug_struct("SledStore") f.debug_struct("SledStore").field("path", &"memory store").finish()
.field("path", &"memory store")
.finish()
} }
} }
} }
@ -236,8 +228,7 @@ impl SledStore {
} else { } else {
let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?; let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?;
let encrypted_key = DatabaseType::Encrypted( let encrypted_key = DatabaseType::Encrypted(
key.export(passphrase) key.export(passphrase).map_err::<StoreError, _>(|e| e.into())?,
.map_err::<StoreError, _>(|e| e.into())?,
); );
db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?; db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?;
key key
@ -275,8 +266,7 @@ impl SledStore {
} }
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.session self.session.insert(("filter", filter_name).encode(), filter_id)?;
.insert(("filter", filter_name).encode(), filter_id)?;
Ok(()) Ok(())
} }
@ -476,11 +466,7 @@ impl SledStore {
} }
pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> { pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
Ok(self Ok(self.presence.get(user_id.encode())?.map(|e| self.deserialize_event(&e)).transpose()?)
.presence
.get(user_id.encode())?
.map(|e| self.deserialize_event(&e))
.transpose()?)
} }
pub async fn get_state_event( pub async fn get_state_event(
@ -531,14 +517,10 @@ impl SledStore {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> impl Stream<Item = Result<UserId>> { ) -> impl Stream<Item = Result<UserId>> {
stream::iter( stream::iter(self.invited_user_ids.scan_prefix(room_id.encode()).map(|u| {
self.invited_user_ids
.scan_prefix(room_id.encode())
.map(|u| {
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string()) UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.map_err(StoreError::Identifier) .map_err(StoreError::Identifier)
}), }))
)
} }
pub async fn get_joined_user_ids( pub async fn get_joined_user_ids(
@ -554,9 +536,7 @@ impl SledStore {
pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> { pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> {
let db = self.clone(); let db = self.clone();
stream::iter( stream::iter(
self.room_info self.room_info.iter().map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
.iter()
.map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
) )
} }
@ -680,8 +660,7 @@ impl StateStore for SledStore {
room_id: &RoomId, room_id: &RoomId,
display_name: &str, display_name: &str,
) -> Result<BTreeSet<UserId>> { ) -> Result<BTreeSet<UserId>> {
self.get_users_with_display_name(room_id, display_name) self.get_users_with_display_name(room_id, display_name).await
.await
} }
async fn get_account_data_event( async fn get_account_data_event(
@ -767,11 +746,7 @@ mod test {
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let user_id = user_id(); let user_id = user_id();
assert!(store assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_none());
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_none());
let mut changes = StateChanges::default(); let mut changes = StateChanges::default();
changes changes
.members .members
@ -780,11 +755,7 @@ mod test {
.insert(user_id.clone(), membership_event()); .insert(user_id.clone(), membership_event());
store.save_changes(&changes).await.unwrap(); store.save_changes(&changes).await.unwrap();
assert!(store assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some());
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_some());
} }
#[async_test] #[async_test]

View File

@ -169,10 +169,7 @@ impl StoreKey {
cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?; cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?;
Ok(EncryptedStoreKey { Ok(EncryptedStoreKey {
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: KDF_ROUNDS, kdf_salt: salt },
rounds: KDF_ROUNDS,
kdf_salt: salt,
},
ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext }, ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext },
}) })
} }
@ -195,11 +192,7 @@ impl StoreKey {
let ciphertext = cipher.encrypt(xnonce, event.as_ref())?; let ciphertext = cipher.encrypt(xnonce, event.as_ref())?;
Ok(EncryptedEvent { Ok(EncryptedEvent { version: VERSION, ciphertext, nonce })
version: VERSION,
ciphertext,
nonce,
})
} }
pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> { pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> {

View File

@ -104,16 +104,14 @@ pub struct SyncRoomEvent {
impl From<Raw<AnySyncRoomEvent>> for SyncRoomEvent { impl From<Raw<AnySyncRoomEvent>> for SyncRoomEvent {
fn from(inner: Raw<AnySyncRoomEvent>) -> Self { fn from(inner: Raw<AnySyncRoomEvent>) -> Self {
Self { Self { encryption_info: None, event: inner }
encryption_info: None,
event: inner,
}
} }
} }
#[derive(Clone, Debug, Default, Deserialize, Serialize)] #[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SyncResponse { pub struct SyncResponse {
/// The batch token to supply in the `since` param of the next `/sync` request. /// The batch token to supply in the `since` param of the next `/sync`
/// request.
pub next_batch: String, pub next_batch: String,
/// Updates to rooms. /// Updates to rooms.
pub rooms: Rooms, pub rooms: Rooms,
@ -138,10 +136,7 @@ pub struct SyncResponse {
impl SyncResponse { impl SyncResponse {
pub fn new(next_batch: String) -> Self { pub fn new(next_batch: String) -> Self {
Self { Self { next_batch, ..Default::default() }
next_batch,
..Default::default()
}
} }
} }
@ -162,14 +157,15 @@ pub struct JoinedRoom {
pub unread_notifications: UnreadNotificationsCount, pub unread_notifications: UnreadNotificationsCount,
/// The timeline of messages and state changes in the room. /// The timeline of messages and state changes in the room.
pub timeline: Timeline, pub timeline: Timeline,
/// Updates to the state, between the time indicated by the `since` parameter, and the start /// Updates to the state, between the time indicated by the `since`
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not /// parameter, and the start of the `timeline` (or all state up to the
/// given, or `full_state` is true). /// start of the `timeline`, if `since` is not given, or `full_state` is
/// true).
pub state: State, pub state: State,
/// The private data that this user has attached to this room. /// The private data that this user has attached to this room.
pub account_data: RoomAccountData, pub account_data: RoomAccountData,
/// The ephemeral events in the room that aren't recorded in the timeline or state of the /// The ephemeral events in the room that aren't recorded in the timeline or
/// room. e.g. typing. /// state of the room. e.g. typing.
pub ephemeral: Ephemeral, pub ephemeral: Ephemeral,
} }
@ -181,20 +177,15 @@ impl JoinedRoom {
ephemeral: Ephemeral, ephemeral: Ephemeral,
unread_notifications: UnreadNotificationsCount, unread_notifications: UnreadNotificationsCount,
) -> Self { ) -> Self {
Self { Self { unread_notifications, timeline, state, account_data, ephemeral }
unread_notifications,
timeline,
state,
account_data,
ephemeral,
}
} }
} }
/// Counts of unread notifications for a room. /// Counts of unread notifications for a room.
#[derive(Copy, Clone, Debug, Default, Deserialize, Serialize)] #[derive(Copy, Clone, Debug, Default, Deserialize, Serialize)]
pub struct UnreadNotificationsCount { pub struct UnreadNotificationsCount {
/// The number of unread notifications for this room with the highlight flag set. /// The number of unread notifications for this room with the highlight flag
/// set.
pub highlight_count: u64, pub highlight_count: u64,
/// The total number of unread notifications for this room. /// The total number of unread notifications for this room.
pub notification_count: u64, pub notification_count: u64,
@ -204,10 +195,7 @@ impl From<RumaUnreadNotificationsCount> for UnreadNotificationsCount {
fn from(notifications: RumaUnreadNotificationsCount) -> Self { fn from(notifications: RumaUnreadNotificationsCount) -> Self {
Self { Self {
highlight_count: notifications.highlight_count.map(|c| c.into()).unwrap_or(0), highlight_count: notifications.highlight_count.map(|c| c.into()).unwrap_or(0),
notification_count: notifications notification_count: notifications.notification_count.map(|c| c.into()).unwrap_or(0),
.notification_count
.map(|c| c.into())
.unwrap_or(0),
} }
} }
} }
@ -217,9 +205,10 @@ pub struct LeftRoom {
/// The timeline of messages and state changes in the room up to the point /// The timeline of messages and state changes in the room up to the point
/// when the user left. /// when the user left.
pub timeline: Timeline, pub timeline: Timeline,
/// Updates to the state, between the time indicated by the `since` parameter, and the start /// Updates to the state, between the time indicated by the `since`
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not /// parameter, and the start of the `timeline` (or all state up to the
/// given, or `full_state` is true). /// start of the `timeline`, if `since` is not given, or `full_state` is
/// true).
pub state: State, pub state: State,
/// The private data that this user has attached to this room. /// The private data that this user has attached to this room.
pub account_data: RoomAccountData, pub account_data: RoomAccountData,
@ -227,18 +216,15 @@ pub struct LeftRoom {
impl LeftRoom { impl LeftRoom {
pub fn new(timeline: Timeline, state: State, account_data: RoomAccountData) -> Self { pub fn new(timeline: Timeline, state: State, account_data: RoomAccountData) -> Self {
Self { Self { timeline, state, account_data }
timeline,
state,
account_data,
}
} }
} }
/// Events in the room. /// Events in the room.
#[derive(Clone, Debug, Default, Deserialize, Serialize)] #[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Timeline { pub struct Timeline {
/// True if the number of events returned was limited by the `limit` on the filter. /// True if the number of events returned was limited by the `limit` on the
/// filter.
pub limited: bool, pub limited: bool,
/// A token that can be supplied to to the `from` parameter of the /// A token that can be supplied to to the `from` parameter of the
@ -251,11 +237,7 @@ pub struct Timeline {
impl Timeline { impl Timeline {
pub fn new(limited: bool, prev_batch: Option<String>) -> Self { pub fn new(limited: bool, prev_batch: Option<String>) -> Self {
Self { Self { limited, prev_batch, ..Default::default() }
limited,
prev_batch,
..Default::default()
}
} }
} }

View File

@ -49,17 +49,12 @@ fn huge_keys_query_resopnse() -> get_keys::Response {
} }
pub fn keys_query(c: &mut Criterion) { pub fn keys_query(c: &mut Criterion) {
let runtime = Builder::new_multi_thread() let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = keys_query_response(); let response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
let count = response let count = response.device_keys.values().fold(0, |acc, d| acc + d.len())
.device_keys
.values()
.fold(0, |acc, d| acc + d.len())
+ response.master_keys.len() + response.master_keys.len()
+ response.self_signing_keys.len() + response.self_signing_keys.len()
+ response.user_signing_keys.len(); + response.user_signing_keys.len();
@ -69,14 +64,10 @@ pub fn keys_query(c: &mut Criterion) {
let name = format!("{} device and cross signing keys", count); let name = format!("{} device and cross signing keys", count);
group.bench_with_input( group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
BenchmarkId::new("memory store", &name),
&response,
|b, response| {
b.to_async(&runtime) b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() }) .iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
}, });
);
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
let machine = runtime let machine = runtime
@ -88,44 +79,30 @@ pub fn keys_query(c: &mut Criterion) {
)) ))
.unwrap(); .unwrap();
group.bench_with_input( group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
BenchmarkId::new("sled store", &name),
&response,
|b, response| {
b.to_async(&runtime) b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() }) .iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
}, });
);
group.finish() group.finish()
} }
pub fn keys_claiming(c: &mut Criterion) { pub fn keys_claiming(c: &mut Criterion) {
let runtime = Arc::new( let runtime = Arc::new(Builder::new_multi_thread().build().expect("Can't create runtime"));
Builder::new_multi_thread()
.build()
.expect("Can't create runtime"),
);
let keys_query_response = keys_query_response(); let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
let response = keys_claim_response(); let response = keys_claim_response();
let count = response let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
.one_time_keys
.values()
.fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Olm session creation"); let mut group = c.benchmark_group("Olm session creation");
group.throughput(Throughput::Elements(count as u64)); group.throughput(Throughput::Elements(count as u64));
let name = format!("{} one-time keys", count); let name = format!("{} one-time keys", count);
group.bench_with_input( group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
BenchmarkId::new("memory store", &name),
&response,
|b, response| {
b.iter_batched( b.iter_batched(
|| { || {
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
@ -135,19 +112,13 @@ pub fn keys_claiming(c: &mut Criterion) {
(machine, runtime.clone()) (machine, runtime.clone())
}, },
move |(machine, runtime)| { move |(machine, runtime)| {
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
}, },
BatchSize::SmallInput, BatchSize::SmallInput,
) )
}, });
);
group.bench_with_input( group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
BenchmarkId::new("sled store", &name),
&response,
|b, response| {
b.iter_batched( b.iter_batched(
|| { || {
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
@ -165,22 +136,17 @@ pub fn keys_claiming(c: &mut Criterion) {
(machine, runtime.clone()) (machine, runtime.clone())
}, },
move |(machine, runtime)| { move |(machine, runtime)| {
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
}, },
BatchSize::SmallInput, BatchSize::SmallInput,
) )
}, });
);
group.finish() group.finish()
} }
pub fn room_key_sharing(c: &mut Criterion) { pub fn room_key_sharing(c: &mut Criterion) {
let runtime = Builder::new_multi_thread() let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
.build()
.expect("Can't create runtime");
let keys_query_response = keys_query_response(); let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
@ -190,18 +156,11 @@ pub fn room_key_sharing(c: &mut Criterion) {
let to_device_response = ToDeviceResponse::new(); let to_device_response = ToDeviceResponse::new();
let users: Vec<UserId> = keys_query_response.device_keys.keys().cloned().collect(); let users: Vec<UserId> = keys_query_response.device_keys.keys().cloned().collect();
let count = response let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
.one_time_keys
.values()
.fold(0, |acc, d| acc + d.len());
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)) runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
let mut group = c.benchmark_group("Room key sharing"); let mut group = c.benchmark_group("Room key sharing");
group.throughput(Throughput::Elements(count as u64)); group.throughput(Throughput::Elements(count as u64));
@ -217,10 +176,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
assert!(!requests.is_empty()); assert!(!requests.is_empty());
for request in requests { for request in requests {
machine machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
.mark_request_as_sent(&request.txn_id, &to_device_response)
.await
.unwrap();
} }
machine.invalidate_group_session(&room_id).await.unwrap(); machine.invalidate_group_session(&room_id).await.unwrap();
@ -236,12 +192,8 @@ pub fn room_key_sharing(c: &mut Criterion) {
None, None,
)) ))
.unwrap(); .unwrap();
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)) runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| { group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime).iter(|| async { b.to_async(&runtime).iter(|| async {
@ -253,10 +205,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
assert!(!requests.is_empty()); assert!(!requests.is_empty());
for request in requests { for request in requests {
machine machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
.mark_request_as_sent(&request.txn_id, &to_device_response)
.await
.unwrap();
} }
machine.invalidate_group_session(&room_id).await.unwrap(); machine.invalidate_group_session(&room_id).await.unwrap();
@ -267,28 +216,21 @@ pub fn room_key_sharing(c: &mut Criterion) {
} }
pub fn devices_missing_sessions_collecting(c: &mut Criterion) { pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
let runtime = Builder::new_multi_thread() let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = huge_keys_query_resopnse(); let response = huge_keys_query_resopnse();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
let users: Vec<UserId> = response.device_keys.keys().cloned().collect(); let users: Vec<UserId> = response.device_keys.keys().cloned().collect();
let count = response let count = response.device_keys.values().fold(0, |acc, d| acc + d.len());
.device_keys
.values()
.fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Devices missing sessions collecting"); let mut group = c.benchmark_group("Devices missing sessions collecting");
group.throughput(Throughput::Elements(count as u64)); group.throughput(Throughput::Elements(count as u64));
let name = format!("{} devices", count); let name = format!("{} devices", count);
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("memory store", &name), |b| { group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(&runtime).iter_with_large_drop(|| async { b.to_async(&runtime).iter_with_large_drop(|| async {
@ -306,9 +248,7 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
)) ))
.unwrap(); .unwrap();
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| { group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime) b.to_async(&runtime)

View File

@ -45,10 +45,7 @@ pub struct FlamegraphProfiler<'a> {
impl<'a> FlamegraphProfiler<'a> { impl<'a> FlamegraphProfiler<'a> {
pub fn new(frequency: c_int) -> Self { pub fn new(frequency: c_int) -> Self {
FlamegraphProfiler { FlamegraphProfiler { frequency, active_profiler: None }
frequency,
active_profiler: None,
}
} }
} }

View File

@ -55,10 +55,7 @@ impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> {
if hash.as_slice() == self.expected_hash.as_slice() { if hash.as_slice() == self.expected_hash.as_slice() {
Ok(0) Ok(0)
} else { } else {
Err(IoError::new( Err(IoError::new(ErrorKind::Other, "Hash missmatch while decrypting"))
ErrorKind::Other,
"Hash missmatch while decrypting",
))
} }
} else { } else {
self.sha.update(&buf[0..read_bytes]); self.sha.update(&buf[0..read_bytes]);
@ -126,23 +123,14 @@ impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
return Err(DecryptorError::UnknownVersion); return Err(DecryptorError::UnknownVersion);
} }
let hash = decode( let hash = decode(info.hashes.get("sha256").ok_or(DecryptorError::MissingHash)?)?;
info.hashes
.get("sha256")
.ok_or(DecryptorError::MissingHash)?,
)?;
let key = Zeroizing::from(decode_url_safe(info.web_key.k)?); let key = Zeroizing::from(decode_url_safe(info.web_key.k)?);
let iv = decode(info.iv)?; let iv = decode(info.iv)?;
let sha = Sha256::default(); let sha = Sha256::default();
let aes = Aes256Ctr::new_var(&key, &iv).map_err(|_| DecryptorError::KeyNonceLength)?; let aes = Aes256Ctr::new_var(&key, &iv).map_err(|_| DecryptorError::KeyNonceLength)?;
Ok(AttachmentDecryptor { Ok(AttachmentDecryptor { inner_reader: input, expected_hash: hash, sha, aes })
inner_reader: input,
expected_hash: hash,
sha,
aes,
})
} }
} }
@ -164,9 +152,7 @@ impl<'a, R: Read + 'a> Read for AttachmentEncryptor<'a, R> {
if read_bytes == 0 { if read_bytes == 0 {
let hash = self.sha.finalize_reset(); let hash = self.sha.finalize_reset();
self.hashes self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
Ok(0) Ok(0)
} else { } else {
self.aes.apply_keystream(&mut buf[0..read_bytes]); self.aes.apply_keystream(&mut buf[0..read_bytes]);
@ -240,9 +226,7 @@ impl<'a, R: Read + 'a> AttachmentEncryptor<'a, R> {
/// Consume the encryptor and get the encryption key. /// Consume the encryptor and get the encryption key.
pub fn finish(mut self) -> EncryptionInfo { pub fn finish(mut self) -> EncryptionInfo {
let hash = self.sha.finalize(); let hash = self.sha.finalize();
self.hashes self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
EncryptionInfo { EncryptionInfo {
version: VERSION.to_string(), version: VERSION.to_string(),

View File

@ -98,14 +98,10 @@ pub fn decrypt_key_export(
return Err(KeyExportError::InvalidHeaders); return Err(KeyExportError::InvalidHeaders);
} }
let payload: String = x let payload: String =
.lines() x.lines().filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER))).collect();
.filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER)))
.collect();
Ok(serde_json::from_str(&decrypt_helper( Ok(serde_json::from_str(&decrypt_helper(&payload, passphrase)?)?)
&payload, passphrase,
)?)?)
} }
/// Encrypt the list of exported room keys using the given passphrase. /// Encrypt the list of exported room keys using the given passphrase.
@ -260,10 +256,7 @@ mod test {
"}; "};
fn export_wihtout_headers() -> String { fn export_wihtout_headers() -> String {
TEST_EXPORT TEST_EXPORT.lines().filter(|l| !l.starts_with("-----")).collect()
.lines()
.filter(|l| !l.starts_with("-----"))
.collect()
} }
#[test] #[test]
@ -300,14 +293,8 @@ mod test {
let (machine, _) = get_prepared_machine().await; let (machine, _) = get_prepared_machine().await;
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
machine machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
.create_outbound_group_session_with_defaults(&room_id) let export = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap();
.await
.unwrap();
let export = machine
.export_keys(|s| s.room_id() == &room_id)
.await
.unwrap();
assert!(!export.is_empty()); assert!(!export.is_empty());
@ -315,10 +302,7 @@ mod test {
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap(); let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
assert_eq!(export, decrypted); assert_eq!(export, decrypted);
assert_eq!( assert_eq!(machine.import_keys(decrypted, |_, _| {}).await.unwrap(), (0, 1));
machine.import_keys(decrypted, |_, _| {}).await.unwrap(),
(0, 1)
);
} }
#[test] #[test]

View File

@ -113,9 +113,7 @@ pub struct Device {
impl std::fmt::Debug for Device { impl std::fmt::Debug for Device {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Device") f.debug_struct("Device").field("device", &self.inner).finish()
.field("device", &self.inner)
.finish()
} }
} }
@ -132,10 +130,7 @@ impl Device {
/// ///
/// Returns a `Sas` object and to-device request that needs to be sent out. /// Returns a `Sas` object and to-device request that needs to be sent out.
pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> { pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> {
let (sas, request) = self let (sas, request) = self.verification_machine.start_sas(self.inner.clone()).await?;
.verification_machine
.start_sas(self.inner.clone())
.await?;
if let OutgoingVerificationRequest::ToDevice(r) = request { if let OutgoingVerificationRequest::ToDevice(r) = request {
Ok((sas, r)) Ok((sas, r))
@ -155,8 +150,7 @@ impl Device {
/// Get the trust state of the device. /// Get the trust state of the device.
pub fn trust_state(&self) -> bool { pub fn trust_state(&self) -> bool {
self.inner self.inner.trust_state(&self.own_identity, &self.device_owner_identity)
.trust_state(&self.own_identity, &self.device_owner_identity)
} }
/// Set the local trust state of the device to the given state. /// Set the local trust state of the device to the given state.
@ -171,10 +165,7 @@ impl Device {
self.inner.set_trust_state(trust_state); self.inner.set_trust_state(trust_state);
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: vec![self.inner.clone()], ..Default::default() },
changed: vec![self.inner.clone()],
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -193,9 +184,7 @@ impl Device {
event_type: EventType, event_type: EventType,
content: Value, content: Value,
) -> OlmResult<(Session, EncryptedEventContent)> { ) -> OlmResult<(Session, EncryptedEventContent)> {
self.inner self.inner.encrypt(&**self.verification_machine.store, event_type, content).await
.encrypt(&**self.verification_machine.store, event_type, content)
.await
} }
/// Encrypt the given inbound group session as a forwarded room key for this /// Encrypt the given inbound group session as a forwarded room key for this
@ -254,9 +243,7 @@ impl UserDevices {
/// Returns true if there is at least one devices of this user that is /// Returns true if there is at least one devices of this user that is
/// considered to be verified, false otherwise. /// considered to be verified, false otherwise.
pub fn is_any_verified(&self) -> bool { pub fn is_any_verified(&self) -> bool {
self.inner self.inner.values().any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
.values()
.any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
@ -341,8 +328,7 @@ impl ReadOnlyDevice {
/// Get the key of the given key algorithm belonging to this device. /// Get the key of the given key algorithm belonging to this device.
pub fn get_key(&self, algorithm: DeviceKeyAlgorithm) -> Option<&String> { pub fn get_key(&self, algorithm: DeviceKeyAlgorithm) -> Option<&String> {
self.keys self.keys.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
} }
/// Get a map containing all the device keys. /// Get a map containing all the device keys.
@ -489,9 +475,8 @@ impl ReadOnlyDevice {
} }
fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> { fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
let signing_key = self let signing_key =
.get_key(DeviceKeyAlgorithm::Ed25519) self.get_key(DeviceKeyAlgorithm::Ed25519).ok_or(SignatureError::MissingSigningKey)?;
.ok_or(SignatureError::MissingSigningKey)?;
let utility = Utility::new(); let utility = Utility::new();
@ -634,10 +619,7 @@ pub(crate) mod test {
assert_eq!(device_id, device.device_id()); assert_eq!(device_id, device.device_id());
assert_eq!(device.algorithms.len(), 2); assert_eq!(device.algorithms.len(), 2);
assert_eq!(LocalTrust::Unset, device.local_trust_state()); assert_eq!(LocalTrust::Unset, device.local_trust_state());
assert_eq!( assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
"Alice's mobile phone",
device.display_name().as_ref().unwrap()
);
assert_eq!( assert_eq!(
device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap(), device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap(),
"xfgbLIC5WAl1OIkpOzoxpCe8FsRDT6nch7NQsOb15nc" "xfgbLIC5WAl1OIkpOzoxpCe8FsRDT6nch7NQsOb15nc"
@ -652,10 +634,7 @@ pub(crate) mod test {
fn update_a_device() { fn update_a_device() {
let mut device = get_device(); let mut device = get_device();
assert_eq!( assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
"Alice's mobile phone",
device.display_name().as_ref().unwrap()
);
let display_name = "Alice's work computer".to_owned(); let display_name = "Alice's work computer".to_owned();

View File

@ -54,11 +54,7 @@ impl IdentityManager {
const MAX_KEY_QUERY_USERS: usize = 250; const MAX_KEY_QUERY_USERS: usize = 250;
pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceIdBox>, store: Store) -> Self { pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceIdBox>, store: Store) -> Self {
IdentityManager { IdentityManager { user_id, device_id, store }
user_id,
device_id,
store,
}
} }
fn user_id(&self) -> &UserId { fn user_id(&self) -> &UserId {
@ -78,9 +74,8 @@ impl IdentityManager {
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
let changed_devices = self let changed_devices =
.handle_devices_from_key_query(response.device_keys.clone()) self.handle_devices_from_key_query(response.device_keys.clone()).await?;
.await?;
let changed_identities = self.handle_cross_singing_keys(response).await?; let changed_identities = self.handle_cross_singing_keys(response).await?;
let changes = Changes { let changes = Changes {
@ -104,9 +99,8 @@ impl IdentityManager {
store: Store, store: Store,
device_keys: DeviceKeys, device_keys: DeviceKeys,
) -> StoreResult<DeviceChange> { ) -> StoreResult<DeviceChange> {
let old_device = store let old_device =
.get_readonly_device(&device_keys.user_id, &device_keys.device_id) store.get_readonly_device(&device_keys.user_id, &device_keys.device_id).await?;
.await?;
if let Some(mut device) = old_device { if let Some(mut device) = old_device {
if let Err(e) = device.update_device(&device_keys) { if let Err(e) = device.update_device(&device_keys) {
@ -148,9 +142,7 @@ impl IdentityManager {
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect(); let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
let tasks = device_map let tasks = device_map.into_iter().filter_map(|(device_id, device_keys)| {
.into_iter()
.filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store. // We don't need our own device in the device store.
if user_id == *own_user_id && device_id == *own_device_id { if user_id == *own_user_id && device_id == *own_device_id {
None None
@ -161,10 +153,7 @@ impl IdentityManager {
); );
None None
} else { } else {
Some(spawn(Self::update_or_create_device( Some(spawn(Self::update_or_create_device(store.clone(), device_keys)))
store.clone(),
device_keys,
)))
} }
}); });
@ -211,9 +200,7 @@ impl IdentityManager {
) -> StoreResult<DeviceChanges> { ) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default(); let mut changes = DeviceChanges::default();
let tasks = device_keys_map let tasks = device_keys_map.into_iter().map(|(user_id, device_keys_map)| {
.into_iter()
.map(|(user_id, device_keys_map)| {
spawn(Self::update_user_devices( spawn(Self::update_user_devices(
self.store.clone(), self.store.clone(),
self.user_id.clone(), self.user_id.clone(),
@ -254,10 +241,7 @@ impl IdentityManager {
let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) { let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) {
SelfSigningPubkey::from(s) SelfSigningPubkey::from(s)
} else { } else {
warn!( warn!("User identity for user {} didn't contain a self signing pubkey", user_id);
"User identity for user {} didn't contain a self signing pubkey",
user_id
);
continue; continue;
}; };
@ -276,13 +260,11 @@ impl IdentityManager {
continue; continue;
}; };
identity identity.update(master_key, self_signing, user_signing).map(|_| (i, false))
.update(master_key, self_signing, user_signing) }
.map(|_| (i, false)) UserIdentities::Other(ref mut identity) => {
identity.update(master_key, self_signing).map(|_| (i, false))
} }
UserIdentities::Other(ref mut identity) => identity
.update(master_key, self_signing)
.map(|_| (i, false)),
} }
} else if user_id == self.user_id() { } else if user_id == self.user_id() {
if let Some(s) = response.user_signing_keys.get(user_id) { if let Some(s) = response.user_signing_keys.get(user_id) {
@ -310,10 +292,7 @@ impl IdentityManager {
continue; continue;
} }
} else if master_key.user_id() != user_id || self_signing.user_id() != user_id { } else if master_key.user_id() != user_id || self_signing.user_id() != user_id {
warn!( warn!("User id mismatch in one of the cross signing keys for user {}", user_id);
"User id mismatch in one of the cross signing keys for user {}",
user_id
);
continue; continue;
} else { } else {
UserIdentity::new(master_key, self_signing) UserIdentity::new(master_key, self_signing)
@ -322,11 +301,7 @@ impl IdentityManager {
match result { match result {
Ok((i, new)) => { Ok((i, new)) => {
trace!( trace!("Updated or created new user identity for {}: {:?}", user_id, i);
"Updated or created new user identity for {}: {:?}",
user_id,
i
);
if new { if new {
changes.new.push(i); changes.new.push(i);
} else { } else {
@ -334,10 +309,7 @@ impl IdentityManager {
} }
} }
Err(e) => { Err(e) => {
warn!( warn!("Couldn't update or create new user identity for {}: {:?}", user_id, e);
"Couldn't update or create new user identity for {}: {:?}",
user_id, e
);
continue; continue;
} }
} }
@ -635,10 +607,7 @@ pub(crate) mod test {
let devices = manager.store.get_user_devices(&other_user).await.unwrap(); let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0); assert_eq!(devices.devices().count(), 0);
manager manager.receive_keys_query_response(&other_key_query()).await.unwrap();
.receive_keys_query_response(&other_key_query())
.await
.unwrap();
let devices = manager.store.get_user_devices(&other_user).await.unwrap(); let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 1); assert_eq!(devices.devices().count(), 1);
@ -649,12 +618,7 @@ pub(crate) mod test {
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
let identity = manager let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
.store
.get_user_identity(&other_user)
.await
.unwrap()
.unwrap();
let identity = identity.other().unwrap(); let identity = identity.other().unwrap();
assert!(identity.is_device_signed(&device).is_ok()) assert!(identity.is_device_signed(&device).is_ok())
@ -667,10 +631,7 @@ pub(crate) mod test {
let devices = manager.store.get_user_devices(&other_user).await.unwrap(); let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0); assert_eq!(devices.devices().count(), 0);
manager manager.receive_keys_query_response(&other_key_query()).await.unwrap();
.receive_keys_query_response(&other_key_query())
.await
.unwrap();
let devices = manager.store.get_user_devices(&other_user).await.unwrap(); let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 1); assert_eq!(devices.devices().count(), 1);
@ -681,12 +642,7 @@ pub(crate) mod test {
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
let identity = manager let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
.store
.get_user_identity(&other_user)
.await
.unwrap()
.unwrap();
let identity = identity.other().unwrap(); let identity = identity.other().unwrap();
assert!(identity.is_device_signed(&device).is_ok()) assert!(identity.is_device_signed(&device).is_ok())

View File

@ -225,12 +225,7 @@ impl MasterPubkey {
&self, &self,
subkey: impl Into<CrossSigningSubKeys<'a>>, subkey: impl Into<CrossSigningSubKeys<'a>>,
) -> Result<(), SignatureError> { ) -> Result<(), SignatureError> {
let (key_id, key) = self let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
let key_id = DeviceKeyId::try_from(key_id.as_str())?; let key_id = DeviceKeyId::try_from(key_id.as_str())?;
@ -287,12 +282,7 @@ impl UserSigningPubkey {
&self, &self,
master_key: &MasterPubkey, master_key: &MasterPubkey,
) -> Result<(), SignatureError> { ) -> Result<(), SignatureError> {
let (key_id, key) = self let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
// TODO check that the usage is OK. // TODO check that the usage is OK.
@ -335,12 +325,7 @@ impl SelfSigningPubkey {
/// Returns an empty result if the signature check succeeded, otherwise a /// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed. /// SignatureError indicating why the check failed.
pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> { pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> {
let (key_id, key) = self let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
// TODO check that the usage is OK. // TODO check that the usage is OK.
@ -472,37 +457,16 @@ impl UserIdentity {
) -> Result<Self, SignatureError> { ) -> Result<Self, SignatureError> {
master_key.verify_subkey(&self_signing_key)?; master_key.verify_subkey(&self_signing_key)?;
Ok(Self { Ok(Self { user_id: Arc::new(master_key.0.user_id.clone()), master_key, self_signing_key })
user_id: Arc::new(master_key.0.user_id.clone()),
master_key,
self_signing_key,
})
} }
#[cfg(test)] #[cfg(test)]
pub async fn from_private(identity: &PrivateCrossSigningIdentity) -> Self { pub async fn from_private(identity: &PrivateCrossSigningIdentity) -> Self {
let master_key = identity let master_key = identity.master_key.lock().await.as_ref().unwrap().public_key.clone();
.master_key let self_signing_key =
.lock() identity.self_signing_key.lock().await.as_ref().unwrap().public_key.clone();
.await
.as_ref()
.unwrap()
.public_key
.clone();
let self_signing_key = identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key
.clone();
Self { Self { user_id: Arc::new(identity.user_id().clone()), master_key, self_signing_key }
user_id: Arc::new(identity.user_id().clone()),
master_key,
self_signing_key,
}
} }
/// Get the user id of this identity. /// Get the user id of this identity.
@ -644,8 +608,7 @@ impl OwnUserIdentity {
/// Returns an empty result if the signature check succeeded, otherwise a /// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed. /// SignatureError indicating why the check failed.
pub fn is_identity_signed(&self, identity: &UserIdentity) -> Result<(), SignatureError> { pub fn is_identity_signed(&self, identity: &UserIdentity) -> Result<(), SignatureError> {
self.user_signing_key self.user_signing_key.verify_master_key(&identity.master_key)
.verify_master_key(&identity.master_key)
} }
/// Check if the given device has been signed by this identity. /// Check if the given device has been signed by this identity.
@ -790,9 +753,8 @@ pub(crate) mod test {
assert!(identity.is_device_signed(&first).is_err()); assert!(identity.is_device_signed(&first).is_err());
assert!(identity.is_device_signed(&second).is_ok()); assert!(identity.is_device_signed(&second).is_ok());
let private_identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty( let private_identity =
second.user_id().clone(), Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(second.user_id().clone())));
)));
let verification_machine = VerificationMachine::new( let verification_machine = VerificationMachine::new(
ReadOnlyAccount::new(second.user_id(), second.device_id()), ReadOnlyAccount::new(second.user_id(), second.device_id()),
private_identity.clone(), private_identity.clone(),

View File

@ -105,10 +105,8 @@ impl WaitQueue {
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
) -> Vec<( ) -> Vec<((UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>)>
(UserId, DeviceIdBox, String), {
ToDeviceEvent<RoomKeyRequestToDeviceEventContent>,
)> {
self.requests_ids_waiting self.requests_ids_waiting
.remove(&(user_id.to_owned(), device_id.into())) .remove(&(user_id.to_owned(), device_id.into()))
.map(|(_, request_ids)| { .map(|(_, request_ids)| {
@ -204,12 +202,7 @@ fn wrap_key_request_content(
Ok(OutgoingRequest { Ok(OutgoingRequest {
request_id: id, request_id: id,
request: Arc::new( request: Arc::new(
ToDeviceRequest { ToDeviceRequest { event_type: EventType::RoomKeyRequest, txn_id: id, messages }.into(),
event_type: EventType::RoomKeyRequest,
txn_id: id,
messages,
}
.into(),
), ),
}) })
} }
@ -241,10 +234,7 @@ impl KeyRequestMachine {
.await? .await?
.into_iter() .into_iter()
.filter(|i| !i.sent_out) .filter(|i| !i.sent_out)
.map(|info| { .map(|info| info.to_request(self.device_id()).map_err(CryptoStoreError::from))
info.to_request(self.device_id())
.map_err(CryptoStoreError::from)
})
.collect() .collect()
} }
@ -262,11 +252,8 @@ impl KeyRequestMachine {
&self, &self,
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> { ) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let mut key_requests = self.load_outgoing_requests().await?; let mut key_requests = self.load_outgoing_requests().await?;
let key_forwards: Vec<OutgoingRequest> = self let key_forwards: Vec<OutgoingRequest> =
.outgoing_to_device_requests self.outgoing_to_device_requests.iter().map(|i| i.value().clone()).collect();
.iter()
.map(|i| i.value().clone())
.collect();
key_requests.extend(key_forwards); key_requests.extend(key_forwards);
Ok(key_requests) Ok(key_requests)
@ -281,8 +268,7 @@ impl KeyRequestMachine {
let device_id = event.content.requesting_device_id.clone(); let device_id = event.content.requesting_device_id.clone();
let request_id = event.content.request_id.clone(); let request_id = event.content.request_id.clone();
self.incoming_key_requests self.incoming_key_requests.insert((sender, device_id, request_id), event.clone());
.insert((sender, device_id, request_id), event.clone());
} }
/// Handle all the incoming key requests that are queued up and empty our /// Handle all the incoming key requests that are queued up and empty our
@ -401,10 +387,8 @@ impl KeyRequestMachine {
return Ok(None); return Ok(None);
}; };
let device = self let device =
.store self.store.get_device(&event.sender, &event.content.requesting_device_id).await?;
.get_device(&event.sender, &event.content.requesting_device_id)
.await?;
if let Some(device) = device { if let Some(device) = device {
match self.should_share_key(&device, &session).await { match self.should_share_key(&device, &session).await {
@ -461,17 +445,13 @@ impl KeyRequestMachine {
device: &Device, device: &Device,
message_index: Option<u32>, message_index: Option<u32>,
) -> OlmResult<Session> { ) -> OlmResult<Session> {
let (used_session, content) = device let (used_session, content) =
.encrypt_session(session.clone(), message_index) device.encrypt_session(session.clone(), message_index).await?;
.await?;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
.entry(device.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()), DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?, to_raw_value(&content)?,
); );
@ -479,11 +459,7 @@ impl KeyRequestMachine {
let request = OutgoingRequest { let request = OutgoingRequest {
request_id: id, request_id: id,
request: Arc::new( request: Arc::new(
ToDeviceRequest { ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages }
event_type: EventType::RoomEncrypted,
txn_id: id,
messages,
}
.into(), .into(),
), ),
}; };
@ -542,8 +518,8 @@ impl KeyRequestMachine {
} else { } else {
Err(KeyshareDecision::OutboundSessionNotShared) Err(KeyshareDecision::OutboundSessionNotShared)
} }
// Else just check if it's one of our own devices that requested the key and // Else just check if it's one of our own devices that requested the key
// check if the device is trusted. // and check if the device is trusted.
} else if device.user_id() == self.user_id() { } else if device.user_id() == self.user_id() {
own_device_check() own_device_check()
// Otherwise, there's not enough info to decide if we can safely share // Otherwise, there's not enough info to decide if we can safely share
@ -711,9 +687,7 @@ impl KeyRequestMachine {
/// Delete the given outgoing key info. /// Delete the given outgoing key info.
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> { async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
self.store self.store.delete_outgoing_key_request(info.request_id).await
.delete_outgoing_key_request(info.request_id)
.await
} }
/// Mark the outgoing request as sent. /// Mark the outgoing request as sent.
@ -736,20 +710,15 @@ impl KeyRequestMachine {
/// This will queue up a request cancelation. /// This will queue up a request cancelation.
async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> { async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
// TODO perhaps only remove the key info if the first known index is 0. // TODO perhaps only remove the key info if the first known index is 0.
trace!( trace!("Successfully received a forwarded room key for {:#?}", key_info);
"Successfully received a forwarded room key for {:#?}",
key_info
);
self.outgoing_to_device_requests self.outgoing_to_device_requests.remove(&key_info.request_id);
.remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler // TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction. // can delete it in one transaction.
self.delete_key_info(&key_info).await?; self.delete_key_info(&key_info).await?;
let request = key_info.to_cancelation(self.device_id())?; let request = key_info.to_cancelation(self.device_id())?;
self.outgoing_to_device_requests self.outgoing_to_device_requests.insert(request.request_id, request);
.insert(request.request_id, request);
Ok(()) Ok(())
} }
@ -801,10 +770,7 @@ impl KeyRequestMachine {
); );
} }
Ok(( Ok((Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())), session))
Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())),
session,
))
} else { } else {
info!( info!(
"Received a forwarded room key from {}, but no key info was found.", "Received a forwarded room key from {}, but no key info was found.",
@ -919,11 +885,7 @@ mod test {
async fn create_machine() { async fn create_machine() {
let machine = get_machine().await; let machine = get_machine().await;
assert!(machine assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
} }
#[async_test] #[async_test]
@ -931,16 +893,10 @@ mod test {
let machine = get_machine().await; let machine = get_machine().await;
let account = account(); let account = account();
let (_, session) = account let (_, session) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
assert!(machine assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
let (cancel, request) = machine let (cancel, request) = machine
.request_key(session.room_id(), &session.sender_key, session.session_id()) .request_key(session.room_id(), &session.sender_key, session.session_id())
.await .await
@ -948,10 +904,7 @@ mod test {
assert!(cancel.is_none()); assert!(cancel.is_none());
machine machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
let (cancel, _) = machine let (cancel, _) = machine
.request_key(session.room_id(), &session.sender_key, session.session_id()) .request_key(session.room_id(), &session.sender_key, session.session_id())
@ -972,16 +925,10 @@ mod test {
alice_device.set_trust_state(LocalTrust::Verified); alice_device.set_trust_state(LocalTrust::Verified);
machine.store.save_devices(&[alice_device]).await.unwrap(); machine.store.save_devices(&[alice_device]).await.unwrap();
let (_, session) = account let (_, session) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
assert!(machine assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
machine machine
.create_outgoing_key_request( .create_outgoing_key_request(
session.room_id(), session.room_id(),
@ -990,15 +937,8 @@ mod test {
) )
.await .await
.unwrap(); .unwrap();
assert!(!machine assert!(!machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests() assert_eq!(machine.outgoing_to_device_requests().await.unwrap().len(), 1);
.await
.unwrap()
.is_empty());
assert_eq!(
machine.outgoing_to_device_requests().await.unwrap().len(),
1
);
machine machine
.create_outgoing_key_request( .create_outgoing_key_request(
@ -1014,15 +954,8 @@ mod test {
let request = requests.get(0).unwrap(); let request = requests.get(0).unwrap();
machine machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
.mark_outgoing_request_as_sent(request.request_id) assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
.await
.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
} }
#[async_test] #[async_test]
@ -1037,10 +970,8 @@ mod test {
alice_device.set_trust_state(LocalTrust::Verified); alice_device.set_trust_state(LocalTrust::Verified);
machine.store.save_devices(&[alice_device]).await.unwrap(); machine.store.save_devices(&[alice_device]).await.unwrap();
let (_, session) = account let (_, session) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
machine machine
.create_outgoing_key_request( .create_outgoing_key_request(
session.room_id(), session.room_id(),
@ -1060,10 +991,7 @@ mod test {
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent { let mut event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
assert!( assert!(
machine machine
@ -1078,19 +1006,13 @@ mod test {
.is_none() .is_none()
); );
let (_, first_session) = machine let (_, first_session) =
.receive_forwarded_room_key(&session.sender_key, &mut event) machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
.await
.unwrap();
let first_session = first_session.unwrap(); let first_session = first_session.unwrap();
assert_eq!(first_session.first_known_index(), 10); assert_eq!(first_session.first_known_index(), 10);
machine machine.store.save_inbound_group_sessions(&[first_session.clone()]).await.unwrap();
.store
.save_inbound_group_sessions(&[first_session.clone()])
.await
.unwrap();
// Get the cancel request. // Get the cancel request.
let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let request = machine.outgoing_to_device_requests.iter().next().unwrap();
@ -1110,24 +1032,16 @@ mod test {
let requests = machine.outgoing_to_device_requests().await.unwrap(); let requests = machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0]; let request = &requests[0];
machine machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
let export = session.export_at_index(15).await; let export = session.export_at_index(15).await;
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent { let mut event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
let (_, second_session) = machine let (_, second_session) =
.receive_forwarded_room_key(&session.sender_key, &mut event) machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
.await
.unwrap();
assert!(second_session.is_none()); assert!(second_session.is_none());
@ -1135,15 +1049,10 @@ mod test {
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent { let mut event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
let (_, second_session) = machine let (_, second_session) =
.receive_forwarded_room_key(&session.sender_key, &mut event) machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
.await
.unwrap();
assert_eq!(second_session.unwrap().first_known_index(), 0); assert_eq!(second_session.unwrap().first_known_index(), 0);
} }
@ -1153,17 +1062,11 @@ mod test {
let machine = get_machine().await; let machine = get_machine().await;
let account = account(); let account = account();
let own_device = machine let own_device =
.store machine.store.get_device(&alice_id(), &alice_device_id()).await.unwrap().unwrap();
.get_device(&alice_id(), &alice_device_id())
.await
.unwrap()
.unwrap();
let (outbound, inbound) = account let (outbound, inbound) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
// We don't share keys with untrusted devices. // We don't share keys with untrusted devices.
assert_eq!( assert_eq!(
@ -1175,20 +1078,13 @@ mod test {
); );
own_device.set_trust_state(LocalTrust::Verified); own_device.set_trust_state(LocalTrust::Verified);
// Now we do want to share the keys. // Now we do want to share the keys.
assert!(machine assert!(machine.should_share_key(&own_device, &inbound).await.is_ok());
.should_share_key(&own_device, &inbound)
.await
.is_ok());
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
machine.store.save_devices(&[bob_device]).await.unwrap(); machine.store.save_devices(&[bob_device]).await.unwrap();
let bob_device = machine let bob_device =
.store machine.store.get_device(&bob_id(), &bob_device_id()).await.unwrap().unwrap();
.get_device(&bob_id(), &bob_device_id())
.await
.unwrap()
.unwrap();
// We don't share sessions with other user's devices if no outbound // We don't share sessions with other user's devices if no outbound
// session was provided. // session was provided.
@ -1231,17 +1127,12 @@ mod test {
// We now share the session, since it was shared before. // We now share the session, since it was shared before.
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
assert!(machine assert!(machine.should_share_key(&bob_device, &inbound).await.is_ok());
.should_share_key(&bob_device, &inbound)
.await
.is_ok());
// But we don't share some other session that doesn't match our outbound // But we don't share some other session that doesn't match our outbound
// session // session
let (_, other_inbound) = account let (_, other_inbound) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
assert_eq!( assert_eq!(
machine machine
@ -1255,10 +1146,7 @@ mod test {
#[async_test] #[async_test]
async fn key_share_cycle() { async fn key_share_cycle() {
let alice_machine = get_machine().await; let alice_machine = get_machine().await;
let alice_account = Account { let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
inner: account(),
store: alice_machine.store.clone(),
};
let bob_machine = bob_machine(); let bob_machine = bob_machine();
let bob_account = bob_account(); let bob_account = bob_account();
@ -1268,11 +1156,7 @@ mod test {
// We need a trusted device, otherwise we won't request keys // We need a trusted device, otherwise we won't request keys
alice_device.set_trust_state(LocalTrust::Verified); alice_device.set_trust_state(LocalTrust::Verified);
alice_machine alice_machine.store.save_devices(&[alice_device]).await.unwrap();
.store
.save_devices(&[alice_device])
.await
.unwrap();
// Create Olm sessions for our two accounts. // Create Olm sessions for our two accounts.
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
@ -1282,37 +1166,15 @@ mod test {
// Populate our stores with Olm sessions and a Megolm session. // Populate our stores with Olm sessions and a Megolm session.
alice_machine alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
.store alice_machine.store.save_devices(&[bob_device]).await.unwrap();
.save_sessions(&[alice_session]) bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
.await bob_machine.store.save_devices(&[alice_device]).await.unwrap();
.unwrap();
alice_machine
.store
.save_devices(&[bob_device])
.await
.unwrap();
bob_machine
.store
.save_sessions(&[bob_session])
.await
.unwrap();
bob_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
let (group_session, inbound_group_session) = bob_account let (group_session, inbound_group_session) =
.create_group_session_pair_with_defaults(&room_id()) bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
bob_machine bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
.store
.save_inbound_group_sessions(&[inbound_group_session])
.await
.unwrap();
// Alice wants to request the outbound group session from bob. // Alice wants to request the outbound group session from bob.
alice_machine alice_machine
@ -1326,9 +1188,7 @@ mod test {
group_session.mark_shared_with(&alice_id(), &alice_device_id()); group_session.mark_shared_with(&alice_id(), &alice_device_id());
// Put the outbound session into bobs store. // Put the outbound session into bobs store.
bob_machine bob_machine.outbound_group_sessions.insert(group_session.clone());
.outbound_group_sessions
.insert(group_session.clone());
// Get the request and convert it into a event. // Get the request and convert it into a event.
let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
@ -1346,15 +1206,9 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent = let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap(); serde_json::from_str(content.get()).unwrap();
alice_machine alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
.mark_outgoing_request_as_sent(id)
.await
.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
// Bob doesn't have any outgoing requests. // Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests.is_empty()); assert!(bob_machine.outgoing_to_device_requests.is_empty());
@ -1383,10 +1237,7 @@ mod test {
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: bob_id(), content };
sender: bob_id(),
content,
};
// Check that alice doesn't have the session. // Check that alice doesn't have the session.
assert!(alice_machine assert!(alice_machine
@ -1407,11 +1258,7 @@ mod test {
.receive_forwarded_room_key(&decrypted.sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else { } else {
panic!("Invalid decrypted event type"); panic!("Invalid decrypted event type");
} }
@ -1434,10 +1281,7 @@ mod test {
#[async_test] #[async_test]
async fn key_share_cycle_without_session() { async fn key_share_cycle_without_session() {
let alice_machine = get_machine().await; let alice_machine = get_machine().await;
let alice_account = Account { let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
inner: account(),
store: alice_machine.store.clone(),
};
let bob_machine = bob_machine(); let bob_machine = bob_machine();
let bob_account = bob_account(); let bob_account = bob_account();
@ -1447,11 +1291,7 @@ mod test {
// We need a trusted device, otherwise we won't request keys // We need a trusted device, otherwise we won't request keys
alice_device.set_trust_state(LocalTrust::Verified); alice_device.set_trust_state(LocalTrust::Verified);
alice_machine alice_machine.store.save_devices(&[alice_device]).await.unwrap();
.store
.save_devices(&[alice_device])
.await
.unwrap();
// Create Olm sessions for our two accounts. // Create Olm sessions for our two accounts.
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
@ -1461,27 +1301,13 @@ mod test {
// Populate our stores with Olm sessions and a Megolm session. // Populate our stores with Olm sessions and a Megolm session.
alice_machine alice_machine.store.save_devices(&[bob_device]).await.unwrap();
.store bob_machine.store.save_devices(&[alice_device]).await.unwrap();
.save_devices(&[bob_device])
.await
.unwrap();
bob_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
let (group_session, inbound_group_session) = bob_account let (group_session, inbound_group_session) =
.create_group_session_pair_with_defaults(&room_id()) bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
bob_machine bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
.store
.save_inbound_group_sessions(&[inbound_group_session])
.await
.unwrap();
// Alice wants to request the outbound group session from bob. // Alice wants to request the outbound group session from bob.
alice_machine alice_machine
@ -1495,9 +1321,7 @@ mod test {
group_session.mark_shared_with(&alice_id(), &alice_device_id()); group_session.mark_shared_with(&alice_id(), &alice_device_id());
// Put the outbound session into bobs store. // Put the outbound session into bobs store.
bob_machine bob_machine.outbound_group_sessions.insert(group_session.clone());
.outbound_group_sessions
.insert(group_session.clone());
// Get the request and convert it into a event. // Get the request and convert it into a event.
let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
@ -1515,22 +1339,12 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent = let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap(); serde_json::from_str(content.get()).unwrap();
alice_machine alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
.mark_outgoing_request_as_sent(id)
.await
.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
// Bob doesn't have any outgoing requests. // Bob doesn't have any outgoing requests.
assert!(bob_machine assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.users_for_key_claim.is_empty());
assert!(bob_machine.wait_queue.is_empty()); assert!(bob_machine.wait_queue.is_empty());
@ -1538,35 +1352,19 @@ mod test {
bob_machine.receive_incoming_key_request(&event); bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap(); bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob doens't have an outgoing requests since we're lacking a session. // Bob doens't have an outgoing requests since we're lacking a session.
assert!(bob_machine assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(!bob_machine.users_for_key_claim.is_empty()); assert!(!bob_machine.users_for_key_claim.is_empty());
assert!(!bob_machine.wait_queue.is_empty()); assert!(!bob_machine.wait_queue.is_empty());
// We create a session now. // We create a session now.
alice_machine alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
.store bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
.save_sessions(&[alice_session])
.await
.unwrap();
bob_machine
.store
.save_sessions(&[bob_session])
.await
.unwrap();
bob_machine.retry_keyshare(&alice_id(), &alice_device_id()); bob_machine.retry_keyshare(&alice_id(), &alice_device_id());
assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.users_for_key_claim.is_empty());
bob_machine.collect_incoming_key_requests().await.unwrap(); bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob now has an outgoing requests. // Bob now has an outgoing requests.
assert!(!bob_machine assert!(!bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.wait_queue.is_empty()); assert!(bob_machine.wait_queue.is_empty());
// Get the request and convert it to a encrypted to-device event. // Get the request and convert it to a encrypted to-device event.
@ -1588,10 +1386,7 @@ mod test {
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: bob_id(), content };
sender: bob_id(),
content,
};
// Check that alice doesn't have the session. // Check that alice doesn't have the session.
assert!(alice_machine assert!(alice_machine
@ -1612,11 +1407,7 @@ mod test {
.receive_forwarded_room_key(&decrypted.sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else { } else {
panic!("Invalid decrypted event type"); panic!("Invalid decrypted event type");
} }

View File

@ -148,19 +148,12 @@ impl OlmMachine {
let store = Arc::new(store); let store = Arc::new(store);
let verification_machine = let verification_machine =
VerificationMachine::new(account.clone(), user_identity.clone(), store.clone()); VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
let store = Store::new( let store =
user_id.clone(), Store::new(user_id.clone(), user_identity.clone(), store, verification_machine.clone());
user_identity.clone(),
store,
verification_machine.clone(),
);
let device_id: Arc<DeviceIdBox> = Arc::new(device_id); let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
let users_for_key_claim = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new());
let account = Account { let account = Account { inner: account, store: store.clone() };
inner: account,
store: store.clone(),
};
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
@ -244,9 +237,7 @@ impl OlmMachine {
} }
}; };
Ok(OlmMachine::new_helper( Ok(OlmMachine::new_helper(&user_id, device_id, store, account, identity))
&user_id, device_id, store, account, identity,
))
} }
/// Create a new machine with the default crypto store. /// Create a new machine with the default crypto store.
@ -296,19 +287,16 @@ impl OlmMachine {
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> { pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
let mut requests = Vec::new(); let mut requests = Vec::new();
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest { if let Some(r) = self
request_id: Uuid::new_v4(), .keys_for_upload()
request: Arc::new(r.into()), .await
}) { .map(|r| OutgoingRequest { request_id: Uuid::new_v4(), request: Arc::new(r.into()) })
{
requests.push(r); requests.push(r);
} }
for request in self for request in
.identity_manager self.identity_manager.users_for_key_query().await.into_iter().map(|r| OutgoingRequest {
.users_for_key_query()
.await
.into_iter()
.map(|r| OutgoingRequest {
request_id: Uuid::new_v4(), request_id: Uuid::new_v4(),
request: Arc::new(r.into()), request: Arc::new(r.into()),
}) })
@ -318,12 +306,7 @@ impl OlmMachine {
requests.append(&mut self.outgoing_to_device_requests()); requests.append(&mut self.outgoing_to_device_requests());
requests.append(&mut self.verification_machine.outgoing_room_message_requests()); requests.append(&mut self.verification_machine.outgoing_room_message_requests());
requests.append( requests.append(&mut self.key_request_machine.outgoing_to_device_requests().await?);
&mut self
.key_request_machine
.outgoing_to_device_requests()
.await?,
);
Ok(requests) Ok(requests)
} }
@ -374,10 +357,7 @@ impl OlmMachine {
let identity = self.user_identity.lock().await; let identity = self.user_identity.lock().await;
identity.mark_as_shared(); identity.mark_as_shared();
let changes = Changes { let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
private_identity: Some(identity.clone()),
..Default::default()
};
self.store.save_changes(changes).await self.store.save_changes(changes).await
} }
@ -407,10 +387,7 @@ impl OlmMachine {
); );
let changes = Changes { let changes = Changes {
identities: IdentityChanges { identities: IdentityChanges { new: vec![public.into()], ..Default::default() },
new: vec![public.into()],
..Default::default()
},
private_identity: Some(identity.clone()), private_identity: Some(identity.clone()),
..Default::default() ..Default::default()
}; };
@ -422,10 +399,8 @@ impl OlmMachine {
info!("Trying to upload the existing cross signing identity"); info!("Trying to upload the existing cross signing identity");
let request = identity.as_upload_request().await; let request = identity.as_upload_request().await;
// TODO remove this expect. // TODO remove this expect.
let signature_request = identity let signature_request =
.sign_account(&self.account) identity.sign_account(&self.account).await.expect("Can't sign device keys");
.await
.expect("Can't sign device keys");
Ok((request, signature_request)) Ok((request, signature_request))
} }
} }
@ -519,9 +494,7 @@ impl OlmMachine {
/// ///
/// * `response` - The response containing the claimed one-time keys. /// * `response` - The response containing the claimed one-time keys.
async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> { async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
self.session_manager self.session_manager.receive_keys_claim_response(response).await
.receive_keys_claim_response(response)
.await
} }
/// Receive a successful keys query response. /// Receive a successful keys query response.
@ -537,9 +510,7 @@ impl OlmMachine {
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
self.identity_manager self.identity_manager.receive_keys_query_response(response).await
.receive_keys_query_response(response)
.await
} }
/// Get a request to upload E2EE keys to the server. /// Get a request to upload E2EE keys to the server.
@ -677,9 +648,7 @@ impl OlmMachine {
/// Returns true if a session was invalidated, false if there was no session /// Returns true if a session was invalidated, false if there was no session
/// to invalidate. /// to invalidate.
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> { pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
self.group_session_manager self.group_session_manager.invalidate_group_session(room_id).await
.invalidate_group_session(room_id)
.await
} }
/// Get to-device requests to share a group session with users in a room. /// Get to-device requests to share a group session with users in a room.
@ -696,9 +665,7 @@ impl OlmMachine {
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>, encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> { ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.group_session_manager self.group_session_manager.share_group_session(room_id, users, encryption_settings).await
.share_group_session(room_id, users, encryption_settings)
.await
} }
/// Receive and properly handle a decrypted to-device event. /// Receive and properly handle a decrypted to-device event.
@ -717,18 +684,15 @@ impl OlmMachine {
let event = match decrypted.event.deserialize() { let event = match decrypted.event.deserialize() {
Ok(e) => e, Ok(e) => e,
Err(e) => { Err(e) => {
warn!( warn!("Decrypted to-device event failed to be parsed correctly {:?}", e);
"Decrypted to-device event failed to be parsed correctly {:?}",
e
);
return Ok((None, None)); return Ok((None, None));
} }
}; };
match event { match event {
AnyToDeviceEvent::RoomKey(mut e) => Ok(self AnyToDeviceEvent::RoomKey(mut e) => {
.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e) Ok(self.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e).await?)
.await?), }
AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self
.key_request_machine .key_request_machine
.receive_forwarded_room_key(&decrypted.sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
@ -754,14 +718,9 @@ impl OlmMachine {
/// Mark an outgoing to-device requests as sent. /// Mark an outgoing to-device requests as sent.
async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
self.verification_machine.mark_request_as_sent(request_id); self.verification_machine.mark_request_as_sent(request_id);
self.key_request_machine self.key_request_machine.mark_outgoing_request_as_sent(*request_id).await?;
.mark_outgoing_request_as_sent(*request_id) self.group_session_manager.mark_request_as_sent(request_id).await?;
.await?; self.session_manager.mark_outgoing_request_as_sent(request_id);
self.group_session_manager
.mark_request_as_sent(request_id)
.await?;
self.session_manager
.mark_outgoing_request_as_sent(request_id);
Ok(()) Ok(())
} }
@ -810,10 +769,8 @@ impl OlmMachine {
// Always save the account, a new session might get created which also // Always save the account, a new session might get created which also
// touches the account. // touches the account.
let mut changes = Changes { let mut changes =
account: Some(self.account.inner.clone()), Changes { account: Some(self.account.inner.clone()), ..Default::default() };
..Default::default()
};
self.update_one_time_key_count(one_time_keys_counts).await; self.update_one_time_key_count(one_time_keys_counts).await;
@ -830,10 +787,7 @@ impl OlmMachine {
Ok(e) => e, Ok(e) => e,
Err(e) => { Err(e) => {
// Skip invalid events. // Skip invalid events.
warn!( warn!("Received an invalid to-device event {:?} {:?}", e, raw_event);
"Received an invalid to-device event {:?} {:?}",
e, raw_event
);
continue; continue;
} }
}; };
@ -845,10 +799,7 @@ impl OlmMachine {
let decrypted = match self.decrypt_to_device_event(&e).await { let decrypted = match self.decrypt_to_device_event(&e).await {
Ok(e) => e, Ok(e) => e,
Err(err) => { Err(err) => {
warn!( warn!("Failed to decrypt to-device event from {} {}", e.sender, err);
"Failed to decrypt to-device event from {} {}",
e.sender, err
);
if let OlmError::SessionWedged(sender, curve_key) = err { if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self if let Err(e) = self
@ -903,10 +854,7 @@ impl OlmMachine {
events.push(raw_event); events.push(raw_event);
} }
let changed_sessions = self let changed_sessions = self.key_request_machine.collect_incoming_key_requests().await?;
.key_request_machine
.collect_incoming_key_requests()
.await?;
changes.sessions.extend(changed_sessions); changes.sessions.extend(changed_sessions);
@ -1023,25 +971,16 @@ impl OlmMachine {
// TODO check if this is from a verified device. // TODO check if this is from a verified device.
let (decrypted_event, _) = session.decrypt(event).await?; let (decrypted_event, _) = session.decrypt(event).await?;
trace!( trace!("Successfully decrypted a Megolm event {:?}", decrypted_event);
"Successfully decrypted a Megolm event {:?}",
decrypted_event
);
if let Ok(e) = decrypted_event.deserialize() { if let Ok(e) = decrypted_event.deserialize() {
self.verification_machine self.verification_machine.receive_room_event(room_id, &e).await?;
.receive_room_event(room_id, &e)
.await?;
} }
let encryption_info = self let encryption_info =
.get_encryption_info(&session, &event.sender, &content.device_id) self.get_encryption_info(&session, &event.sender, &content.device_id).await?;
.await?;
Ok(SyncRoomEvent { Ok(SyncRoomEvent { encryption_info: Some(encryption_info), event: decrypted_event })
encryption_info: Some(encryption_info),
event: decrypted_event,
})
} }
/// Update the tracked users. /// Update the tracked users.
@ -1197,17 +1136,11 @@ impl OlmMachine {
let num_sessions = sessions.len(); let num_sessions = sessions.len();
let changes = Changes { let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
inbound_group_sessions: sessions,
..Default::default()
};
self.store.save_changes(changes).await?; self.store.save_changes(changes).await?;
info!( info!("Successfully imported {} inbound group sessions", num_sessions);
"Successfully imported {} inbound group sessions",
num_sessions
);
Ok((num_sessions, total_sessions)) Ok((num_sessions, total_sessions))
} }
@ -1318,10 +1251,7 @@ pub(crate) mod test {
} }
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> { pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
Response::builder() Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
.status(200)
.body(json.to_string().as_bytes().to_vec())
.unwrap()
} }
fn keys_upload_response() -> upload_keys::Response { fn keys_upload_response() -> upload_keys::Response {
@ -1340,15 +1270,7 @@ pub(crate) mod test {
let to_device_request = &requests[0]; let to_device_request = &requests[0];
let content: Raw<EncryptedEventContent> = serde_json::from_str( let content: Raw<EncryptedEventContent> = serde_json::from_str(
to_device_request to_device_request.messages.values().next().unwrap().values().next().unwrap().get(),
.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
) )
.unwrap(); .unwrap();
@ -1358,15 +1280,9 @@ pub(crate) mod test {
pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.account.inner.update_uploaded_key_count(0); machine.account.inner.update_uploaded_key_count(0);
let request = machine let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload");
.keys_for_upload()
.await
.expect("Can't prepare initial key upload");
let response = keys_upload_response(); let response = keys_upload_response();
machine machine.receive_keys_upload_response(&response).await.unwrap();
.receive_keys_upload_response(&response)
.await
.unwrap();
(machine, request.one_time_keys.unwrap()) (machine, request.one_time_keys.unwrap())
} }
@ -1375,10 +1291,7 @@ pub(crate) mod test {
let (machine, otk) = get_prepared_machine().await; let (machine, otk) = get_prepared_machine().await;
let response = keys_query_response(); let response = keys_query_response();
machine machine.receive_keys_query_response(&response).await.unwrap();
.receive_keys_query_response(&response)
.await
.unwrap();
(machine, otk) (machine, otk)
} }
@ -1421,28 +1334,15 @@ pub(crate) mod test {
async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) { async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) {
let (alice, bob) = get_machine_pair_with_session().await; let (alice, bob) = get_machine_pair_with_session().await;
let bob_device = alice let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
.get_device(&bob.user_id, &bob.device_id)
.await
.unwrap()
.unwrap();
let (session, content) = bob_device let (session, content) = bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap();
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap();
alice.store.save_sessions(&[session]).await.unwrap(); alice.store.save_sessions(&[session]).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: alice.user_id().clone(), content };
sender: alice.user_id().clone(),
content,
};
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
(alice, bob) (alice, bob)
} }
@ -1458,34 +1358,18 @@ pub(crate) mod test {
let machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
let mut response = keys_upload_response(); let mut response = keys_upload_response();
response response.one_time_key_counts.remove(&DeviceKeyAlgorithm::SignedCurve25519).unwrap();
.one_time_key_counts
.remove(&DeviceKeyAlgorithm::SignedCurve25519)
.unwrap();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
machine machine.receive_keys_upload_response(&response).await.unwrap();
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
response response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
.one_time_key_counts machine.receive_keys_upload_response(&response).await.unwrap();
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
response response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
.one_time_key_counts machine.receive_keys_upload_response(&response).await.unwrap();
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(!machine.should_upload_keys().await); assert!(!machine.should_upload_keys().await);
} }
@ -1497,20 +1381,12 @@ pub(crate) mod test {
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
machine machine.receive_keys_upload_response(&response).await.unwrap();
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
assert!(machine.account.generate_one_time_keys().await.is_ok()); assert!(machine.account.generate_one_time_keys().await.is_ok());
response response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
.one_time_key_counts machine.receive_keys_upload_response(&response).await.unwrap();
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.account.generate_one_time_keys().await.is_err()); assert!(machine.account.generate_one_time_keys().await.is_err());
} }
@ -1537,14 +1413,8 @@ pub(crate) mod test {
let machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
let room_id = room_id!("!test:example.org"); let room_id = room_id!("!test:example.org");
machine machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
.create_outbound_group_session_with_defaults(&room_id) assert!(machine.group_session_manager.get_outbound_group_session(&room_id).is_some());
.await
.unwrap();
assert!(machine
.group_session_manager
.get_outbound_group_session(&room_id)
.is_some());
machine.invalidate_group_session(&room_id).await.unwrap(); machine.invalidate_group_session(&room_id).await.unwrap();
@ -1600,10 +1470,8 @@ pub(crate) mod test {
let identity_keys = machine.account.identity_keys(); let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let mut request = machine let mut request =
.keys_for_upload() machine.keys_for_upload().await.expect("Can't prepare initial key upload");
.await
.expect("Can't prepare initial key upload");
let utility = Utility::new(); let utility = Utility::new();
let ret = utility.verify_json( let ret = utility.verify_json(
@ -1626,15 +1494,10 @@ pub(crate) mod test {
let mut response = keys_upload_response(); let mut response = keys_upload_response();
response.one_time_key_counts.insert( response.one_time_key_counts.insert(
DeviceKeyAlgorithm::SignedCurve25519, DeviceKeyAlgorithm::SignedCurve25519,
(request.one_time_keys.unwrap().len() as u64) (request.one_time_keys.unwrap().len() as u64).try_into().unwrap(),
.try_into()
.unwrap(),
); );
machine machine.receive_keys_upload_response(&response).await.unwrap();
.receive_keys_upload_response(&response)
.await
.unwrap();
let ret = machine.keys_for_upload().await; let ret = machine.keys_for_upload().await;
assert!(ret.is_none()); assert!(ret.is_none());
@ -1650,17 +1513,9 @@ pub(crate) mod test {
let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap(); let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap();
assert!(alice_devices.devices().peekable().peek().is_none()); assert!(alice_devices.devices().peekable().peek().is_none());
machine machine.receive_keys_query_response(&response).await.unwrap();
.receive_keys_query_response(&response)
.await
.unwrap();
let device = machine let device = machine.store.get_device(&alice_id, alice_device_id).await.unwrap().unwrap();
.store
.get_device(&alice_id, alice_device_id)
.await
.unwrap()
.unwrap();
assert_eq!(device.user_id(), &alice_id); assert_eq!(device.user_id(), &alice_id);
assert_eq!(device.device_id(), alice_device_id); assert_eq!(device.device_id(), alice_device_id);
} }
@ -1672,11 +1527,8 @@ pub(crate) mod test {
let alice = alice_id(); let alice = alice_id();
let alice_device = alice_device_id(); let alice_device = alice_device_id();
let (_, missing_sessions) = machine let (_, missing_sessions) =
.get_missing_sessions(&mut [alice.clone()].iter()) machine.get_missing_sessions(&mut [alice.clone()].iter()).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert!(missing_sessions.one_time_keys.contains_key(&alice)); assert!(missing_sessions.one_time_keys.contains_key(&alice));
let user_sessions = missing_sessions.one_time_keys.get(&alice).unwrap(); let user_sessions = missing_sessions.one_time_keys.get(&alice).unwrap();
@ -1699,10 +1551,7 @@ pub(crate) mod test {
let response = claim_keys::Response::new(one_time_keys); let response = claim_keys::Response::new(one_time_keys);
alice_machine alice_machine.receive_keys_claim_response(&response).await.unwrap();
.receive_keys_claim_response(&response)
.await
.unwrap();
let session = alice_machine let session = alice_machine
.store .store
@ -1718,28 +1567,14 @@ pub(crate) mod test {
async fn test_olm_encryption() { async fn test_olm_encryption() {
let (alice, bob) = get_machine_pair_with_session().await; let (alice, bob) = get_machine_pair_with_session().await;
let bob_device = alice let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
.get_device(&bob.user_id, &bob.device_id)
.await
.unwrap()
.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent {
sender: alice.user_id().clone(), sender: alice.user_id().clone(),
content: bob_device content: bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap().1,
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap()
.1,
}; };
let event = bob let event = bob.decrypt_to_device_event(&event).await.unwrap().event.deserialize().unwrap();
.decrypt_to_device_event(&event)
.await
.unwrap()
.event
.deserialize()
.unwrap();
if let AnyToDeviceEvent::Dummy(e) = event { if let AnyToDeviceEvent::Dummy(e) = event {
assert_eq!(&e.sender, alice.user_id()); assert_eq!(&e.sender, alice.user_id());
@ -1768,17 +1603,12 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests), content: to_device_requests_to_content(to_device_requests),
}; };
let alice_session = alice let alice_session =
.group_session_manager alice.group_session_manager.get_outbound_group_session(&room_id).unwrap();
.get_outbound_group_session(&room_id)
.unwrap();
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
bob.store bob.store
.save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()]) .save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()])
.await .await
@ -1823,25 +1653,16 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests), content: to_device_requests_to_content(to_device_requests),
}; };
let group_session = bob let group_session =
.decrypt_to_device_event(&event) bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session;
.await bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
.unwrap()
.inbound_group_session;
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await
.unwrap();
let plaintext = "It is a secret to everybody"; let plaintext = "It is a secret to everybody";
let content = MessageEventContent::text_plain(plaintext); let content = MessageEventContent::text_plain(plaintext);
let encrypted_content = alice let encrypted_content = alice
.encrypt( .encrypt(&room_id, AnyMessageEventContent::RoomMessage(content.clone()))
&room_id,
AnyMessageEventContent::RoomMessage(content.clone()),
)
.await .await
.unwrap(); .unwrap();
@ -1853,13 +1674,8 @@ pub(crate) mod test {
unsigned: Unsigned::default(), unsigned: Unsigned::default(),
}; };
let decrypted_event = bob let decrypted_event =
.decrypt_room_event(&event, &room_id) bob.decrypt_room_event(&event, &room_id).await.unwrap().event.deserialize().unwrap();
.await
.unwrap()
.event
.deserialize()
.unwrap();
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(SyncMessageEvent { if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(SyncMessageEvent {
sender, sender,
@ -1898,10 +1714,7 @@ pub(crate) mod test {
let device_id = machine.device_id().to_owned(); let device_id = machine.device_id().to_owned();
let ed25519_key = machine.identity_keys().ed25519().to_owned(); let ed25519_key = machine.identity_keys().ed25519().to_owned();
machine machine.receive_keys_upload_response(&keys_upload_response()).await.unwrap();
.receive_keys_upload_response(&keys_upload_response())
.await
.unwrap();
drop(machine); drop(machine);
@ -1923,11 +1736,7 @@ pub(crate) mod test {
async fn interactive_verification() { async fn interactive_verification() {
let (alice, bob) = get_machine_pair_with_setup_sessions().await; let (alice, bob) = get_machine_pair_with_setup_sessions().await;
let bob_device = alice let bob_device = alice.get_device(bob.user_id(), bob.device_id()).await.unwrap().unwrap();
.get_device(bob.user_id(), bob.device_id())
.await
.unwrap()
.unwrap();
assert!(!bob_device.is_trusted()); assert!(!bob_device.is_trusted());
@ -1941,10 +1750,7 @@ pub(crate) mod test {
assert!(alice_sas.emoji().is_none()); assert!(alice_sas.emoji().is_none());
assert!(bob_sas.emoji().is_none()); assert!(bob_sas.emoji().is_none());
let event = bob_sas let event = bob_sas.accept().map(|r| request_to_event(bob.user_id(), &r)).unwrap();
.accept()
.map(|r| request_to_event(bob.user_id(), &r))
.unwrap();
alice.handle_verification_event(&event).await; alice.handle_verification_event(&event).await;
@ -1991,11 +1797,8 @@ pub(crate) mod test {
assert!(alice_sas.is_done()); assert!(alice_sas.is_done());
assert!(bob_device.is_trusted()); assert!(bob_device.is_trusted());
let alice_device = bob let alice_device =
.get_device(alice.user_id(), alice.device_id()) bob.get_device(alice.user_id(), alice.device_id()).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert!(!alice_device.is_trusted()); assert!(!alice_device.is_trusted());
bob.handle_verification_event(&event).await; bob.handle_verification_event(&event).await;

View File

@ -139,10 +139,8 @@ impl Account {
// Try to find a ciphertext that was meant for our device. // Try to find a ciphertext that was meant for our device.
if let Some(ciphertext) = own_ciphertext { if let Some(ciphertext) = own_ciphertext {
let message_type: u8 = ciphertext let message_type: u8 =
.message_type ciphertext.message_type.try_into().map_err(|_| EventError::UnsupportedOlmType)?;
.try_into()
.map_err(|_| EventError::UnsupportedOlmType)?;
let sha = Sha256::new() let sha = Sha256::new()
.chain(&content.sender_key) .chain(&content.sender_key)
@ -160,10 +158,8 @@ impl Account {
.map_err(|_| EventError::UnsupportedOlmType)?; .map_err(|_| EventError::UnsupportedOlmType)?;
// Decrypt the OlmMessage and get a Ruma event out of it. // Decrypt the OlmMessage and get a Ruma event out of it.
let (session, event, signing_key) = match self let (session, event, signing_key) =
.decrypt_olm_message(&event.sender, &content.sender_key, message) match self.decrypt_olm_message(&event.sender, &content.sender_key, message).await {
.await
{
Ok(d) => d, Ok(d) => d,
Err(OlmError::SessionWedged(user_id, sender_key)) => { Err(OlmError::SessionWedged(user_id, sender_key)) => {
if self.store.is_message_known(&message_hash).await? { if self.store.is_message_known(&message_hash).await? {
@ -208,9 +204,8 @@ impl Account {
} }
self.inner.mark_as_shared(); self.inner.mark_as_shared();
let one_time_key_count = response let one_time_key_count =
.one_time_key_counts response.one_time_key_counts.get(&DeviceKeyAlgorithm::SignedCurve25519);
.get(&DeviceKeyAlgorithm::SignedCurve25519);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
debug!( debug!(
@ -295,9 +290,8 @@ impl Account {
message: OlmMessage, message: OlmMessage,
) -> OlmResult<(SessionType, Raw<AnyToDeviceEvent>, String)> { ) -> OlmResult<(SessionType, Raw<AnyToDeviceEvent>, String)> {
// First try to decrypt using an existing session. // First try to decrypt using an existing session.
let (session, plaintext) = if let Some(d) = self let (session, plaintext) = if let Some(d) =
.try_decrypt_olm_message(sender, sender_key, &message) self.try_decrypt_olm_message(sender, sender_key, &message).await?
.await?
{ {
// Decryption succeeded, de-structure the session/plaintext out of // Decryption succeeded, de-structure the session/plaintext out of
// the Option. // the Option.
@ -314,19 +308,13 @@ impl Account {
available sessions {} {}", available sessions {} {}",
sender, sender_key sender, sender_key
); );
return Err(OlmError::SessionWedged( return Err(OlmError::SessionWedged(sender.to_owned(), sender_key.to_owned()));
sender.to_owned(),
sender_key.to_owned(),
));
} }
OlmMessage::PreKey(m) => { OlmMessage::PreKey(m) => {
// Create the new session. // Create the new session.
let session = match self let session =
.inner match self.inner.create_inbound_session(sender_key, m.clone()).await {
.create_inbound_session(sender_key, m.clone())
.await
{
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
warn!( warn!(
@ -426,9 +414,8 @@ impl Account {
return Err(EventError::MissmatchedKeys.into()); return Err(EventError::MissmatchedKeys.into());
} }
let signing_key = keys let signing_key =
.get(&DeviceKeyAlgorithm::Ed25519) keys.get(&DeviceKeyAlgorithm::Ed25519).ok_or(EventError::MissingSigningKey)?;
.ok_or(EventError::MissingSigningKey)?;
Ok(( Ok((
Raw::from(serde_json::from_value::<AnyToDeviceEvent>(decrypted_json)?), Raw::from(serde_json::from_value::<AnyToDeviceEvent>(decrypted_json)?),
@ -545,8 +532,7 @@ impl ReadOnlyAccount {
/// * `new_count` - The new count that was reported by the server. /// * `new_count` - The new count that was reported by the server.
pub(crate) fn update_uploaded_key_count(&self, new_count: u64) { pub(crate) fn update_uploaded_key_count(&self, new_count: u64) {
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX); let key_count = i64::try_from(new_count).unwrap_or(i64::MAX);
self.uploaded_signed_key_count self.uploaded_signed_key_count.store(key_count, Ordering::Relaxed);
.store(key_count, Ordering::Relaxed);
} }
/// Get the currently known uploaded key count. /// Get the currently known uploaded key count.
@ -629,19 +615,12 @@ impl ReadOnlyAccount {
/// Returns None if no keys need to be uploaded. /// Returns None if no keys need to be uploaded.
pub(crate) async fn keys_for_upload( pub(crate) async fn keys_for_upload(
&self, &self,
) -> Option<( ) -> Option<(Option<DeviceKeys>, Option<BTreeMap<DeviceKeyId, OneTimeKey>>)> {
Option<DeviceKeys>,
Option<BTreeMap<DeviceKeyId, OneTimeKey>>,
)> {
if !self.should_upload_keys().await { if !self.should_upload_keys().await {
return None; return None;
} }
let device_keys = if !self.shared() { let device_keys = if !self.shared() { Some(self.device_keys().await) } else { None };
Some(self.device_keys().await)
} else {
None
};
let one_time_keys = self.signed_one_time_keys().await.ok(); let one_time_keys = self.signed_one_time_keys().await.ok();
@ -664,7 +643,8 @@ impl ReadOnlyAccount {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `pickle_mode` - The mode that was used to pickle the account, either an /// * `pickle_mode` - The mode that was used to pickle the account, either
/// an
/// unencrypted mode or an encrypted using passphrase. /// unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledAccount { pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledAccount {
let pickle = AccountPickle(self.inner.lock().await.pickle(pickle_mode)); let pickle = AccountPickle(self.inner.lock().await.pickle(pickle_mode));
@ -684,7 +664,8 @@ impl ReadOnlyAccount {
/// ///
/// * `pickle` - The pickled version of the Account. /// * `pickle` - The pickled version of the Account.
/// ///
/// * `pickle_mode` - The mode that was used to pickle the account, either an /// * `pickle_mode` - The mode that was used to pickle the account, either
/// an
/// unencrypted mode or an encrypted using passphrase. /// unencrypted mode or an encrypted using passphrase.
pub fn from_pickle( pub fn from_pickle(
pickle: PickledAccount, pickle: PickledAccount,
@ -740,11 +721,7 @@ impl ReadOnlyAccount {
"keys": device_keys.keys, "keys": device_keys.keys,
}); });
device_keys device_keys.signatures.entry(self.user_id().clone()).or_insert_with(BTreeMap::new).insert(
.signatures
.entry(self.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id), DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
self.sign_json(json_device_keys).await, self.sign_json(json_device_keys).await,
); );
@ -754,11 +731,7 @@ impl ReadOnlyAccount {
pub(crate) async fn bootstrap_cross_signing( pub(crate) async fn bootstrap_cross_signing(
&self, &self,
) -> ( ) -> (PrivateCrossSigningIdentity, UploadSigningKeysRequest, SignatureUploadRequest) {
PrivateCrossSigningIdentity,
UploadSigningKeysRequest,
SignatureUploadRequest,
) {
PrivateCrossSigningIdentity::new_with_account(self).await PrivateCrossSigningIdentity::new_with_account(self).await
} }
@ -871,8 +844,8 @@ impl ReadOnlyAccount {
/// # Arguments /// # Arguments
/// * `device` - The other account's device. /// * `device` - The other account's device.
/// ///
/// * `key_map` - A map from the algorithm and device id to the one-time key that the other /// * `key_map` - A map from the algorithm and device id to the one-time key
/// account created and shared with us. /// that the other account created and shared with us.
pub(crate) async fn create_outbound_session( pub(crate) async fn create_outbound_session(
&self, &self,
device: ReadOnlyDevice, device: ReadOnlyDevice,
@ -909,18 +882,14 @@ impl ReadOnlyAccount {
) )
})?; })?;
let curve_key = device let curve_key = device.get_key(DeviceKeyAlgorithm::Curve25519).ok_or_else(|| {
.get_key(DeviceKeyAlgorithm::Curve25519)
.ok_or_else(|| {
SessionCreationError::DeviceMissingCurveKey( SessionCreationError::DeviceMissingCurveKey(
device.user_id().to_owned(), device.user_id().to_owned(),
device.device_id().into(), device.device_id().into(),
) )
})?; })?;
self.create_outbound_session_helper(curve_key, &one_time_key) self.create_outbound_session_helper(curve_key, &one_time_key).await.map_err(|e| {
.await
.map_err(|e| {
SessionCreationError::OlmError( SessionCreationError::OlmError(
device.user_id().to_owned(), device.user_id().to_owned(),
device.device_id().into(), device.device_id().into(),
@ -944,17 +913,10 @@ impl ReadOnlyAccount {
their_identity_key: &str, their_identity_key: &str,
message: PreKeyMessage, message: PreKeyMessage,
) -> Result<Session, OlmSessionError> { ) -> Result<Session, OlmSessionError> {
let session = self let session =
.inner self.inner.lock().await.create_inbound_session_from(their_identity_key, message)?;
.lock()
.await
.create_inbound_session_from(their_identity_key, message)?;
self.inner self.inner.lock().await.remove_one_time_keys(&session).expect(
.lock()
.await
.remove_one_time_keys(&session)
.expect(
"Session was successfully created but the account doesn't hold a matching one-time key", "Session was successfully created but the account doesn't hold a matching one-time key",
); );
@ -1026,8 +988,7 @@ impl ReadOnlyAccount {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
self.create_group_session_pair(room_id, EncryptionSettings::default()) self.create_group_session_pair(room_id, EncryptionSettings::default()).await
.await
} }
#[cfg(test)] #[cfg(test)]
@ -1037,27 +998,19 @@ impl ReadOnlyAccount {
let device = ReadOnlyDevice::from_account(other).await; let device = ReadOnlyDevice::from_account(other).await;
let mut our_session = self let mut our_session =
.create_outbound_session(device.clone(), &one_time) self.create_outbound_session(device.clone(), &one_time).await.unwrap();
.await
.unwrap();
other.mark_keys_as_published().await; other.mark_keys_as_published().await;
let message = our_session let message = our_session.encrypt(&device, EventType::Dummy, json!({})).await.unwrap();
.encrypt(&device, EventType::Dummy, json!({}))
.await
.unwrap();
let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme { let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme {
c c
} else { } else {
panic!("Invalid encrypted event algorithm"); panic!("Invalid encrypted event algorithm");
}; };
let own_ciphertext = content let own_ciphertext = content.ciphertext.get(other.identity_keys.curve25519()).unwrap();
.ciphertext
.get(other.identity_keys.curve25519())
.unwrap();
let message_type: u8 = own_ciphertext.message_type.try_into().unwrap(); let message_type: u8 = own_ciphertext.message_type.try_into().unwrap();
let message = let message =

View File

@ -147,10 +147,8 @@ impl InboundGroupSession {
forwarding_chains.push(sender_key.to_owned()); forwarding_chains.push(sender_key.to_owned());
let mut sender_claimed_key = BTreeMap::new(); let mut sender_claimed_key = BTreeMap::new();
sender_claimed_key.insert( sender_claimed_key
DeviceKeyAlgorithm::Ed25519, .insert(DeviceKeyAlgorithm::Ed25519, content.sender_claimed_ed25519_key.to_owned());
content.sender_claimed_ed25519_key.to_owned(),
);
Ok(InboundGroupSession { Ok(InboundGroupSession {
inner: Mutex::new(session).into(), inner: Mutex::new(session).into(),
@ -217,11 +215,7 @@ impl InboundGroupSession {
let message_index = std::cmp::max(self.first_known_index(), message_index); let message_index = std::cmp::max(self.first_known_index(), message_index);
let session_key = ExportedGroupSessionKey( let session_key = ExportedGroupSessionKey(
self.inner self.inner.lock().await.export(message_index).expect("Can't export session"),
.lock()
.await
.export(message_index)
.expect("Can't export session"),
); );
ExportedRoomKey { ExportedRoomKey {
@ -314,9 +308,7 @@ impl InboundGroupSession {
let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?; let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?;
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?; let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
let decrypted_object = decrypted_value let decrypted_object = decrypted_value.as_object_mut().ok_or(EventError::NotAnObject)?;
.as_object_mut()
.ok_or(EventError::NotAnObject)?;
// TODO better number conversion here. // TODO better number conversion here.
let server_ts = event let server_ts = event
@ -335,10 +327,8 @@ impl InboundGroupSession {
serde_json::to_value(&event.unsigned).unwrap_or_default(), serde_json::to_value(&event.unsigned).unwrap_or_default(),
); );
if let Some(decrypted_content) = decrypted_object if let Some(decrypted_content) =
.get_mut("content") decrypted_object.get_mut("content").map(|c| c.as_object_mut()).flatten()
.map(|c| c.as_object_mut())
.flatten()
{ {
if !decrypted_content.contains_key("m.relates_to") { if !decrypted_content.contains_key("m.relates_to") {
if let Some(relation) = &event.content.relates_to { if let Some(relation) = &event.content.relates_to {
@ -350,19 +340,14 @@ impl InboundGroupSession {
} }
} }
Ok(( Ok((serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?, message_index))
serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?,
message_index,
))
} }
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
impl fmt::Debug for InboundGroupSession { impl fmt::Debug for InboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InboundGroupSession") f.debug_struct("InboundGroupSession").field("session_id", &self.session_id()).finish()
.field("session_id", &self.session_id())
.finish()
} }
} }

View File

@ -108,10 +108,8 @@ impl From<ForwardedRoomKeyToDeviceEventContent> for ExportedRoomKey {
/// Convert the content of a forwarded room key into a exported room key. /// Convert the content of a forwarded room key into a exported room key.
fn from(forwarded_key: ForwardedRoomKeyToDeviceEventContent) -> Self { fn from(forwarded_key: ForwardedRoomKeyToDeviceEventContent) -> Self {
let mut sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new(); let mut sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new();
sender_claimed_keys.insert( sender_claimed_keys
DeviceKeyAlgorithm::Ed25519, .insert(DeviceKeyAlgorithm::Ed25519, forwarded_key.sender_claimed_ed25519_key);
forwarded_key.sender_claimed_ed25519_key,
);
Self { Self {
algorithm: forwarded_key.algorithm, algorithm: forwarded_key.algorithm,
@ -143,10 +141,7 @@ mod test {
#[tokio::test] #[tokio::test]
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
async fn expiration() { async fn expiration() {
let settings = EncryptionSettings { let settings = EncryptionSettings { rotation_period_msgs: 1, ..Default::default() };
rotation_period_msgs: 1,
..Default::default()
};
let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into()); let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into());
let (session, _) = account let (session, _) = account
@ -156,9 +151,9 @@ mod test {
assert!(!session.expired()); assert!(!session.expired());
let _ = session let _ = session
.encrypt(AnyMessageEventContent::RoomMessage( .encrypt(AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain(
MessageEventContent::text_plain("Test message"), "Test message",
)) )))
.await; .await;
assert!(session.expired()); assert!(session.expired());

View File

@ -96,12 +96,10 @@ impl EncryptionSettings {
/// Create new encryption settings using an `EncryptionEventContent` and a /// Create new encryption settings using an `EncryptionEventContent` and a
/// history visibility. /// history visibility.
pub fn new(content: EncryptionEventContent, history_visibility: HistoryVisibility) -> Self { pub fn new(content: EncryptionEventContent, history_visibility: HistoryVisibility) -> Self {
let rotation_period: Duration = content let rotation_period: Duration =
.rotation_period_ms content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into())); let rotation_period_msgs: u64 =
let rotation_period_msgs: u64 = content content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
.rotation_period_msgs
.map_or(ROTATION_MESSAGES, Into::into);
Self { Self {
algorithm: content.algorithm, algorithm: content.algorithm,
@ -180,8 +178,7 @@ impl OutboundGroupSession {
request: Arc<ToDeviceRequest>, request: Arc<ToDeviceRequest>,
message_index: u32, message_index: u32,
) { ) {
self.to_share_with_set self.to_share_with_set.insert(request_id, (request, message_index));
.insert(request_id, (request, message_index));
} }
/// This should be called if an the user wishes to rotate this session. /// This should be called if an the user wishes to rotate this session.
@ -219,10 +216,7 @@ impl OutboundGroupSession {
}); });
user_pairs.for_each(|(u, d)| { user_pairs.for_each(|(u, d)| {
self.shared_with_set self.shared_with_set.entry(u).or_insert_with(DashMap::new).extend(d);
.entry(u)
.or_insert_with(DashMap::new)
.extend(d);
}); });
if self.to_share_with_set.is_empty() { if self.to_share_with_set.is_empty() {
@ -235,11 +229,8 @@ impl OutboundGroupSession {
self.mark_as_shared(); self.mark_as_shared();
} }
} else { } else {
let request_ids: Vec<String> = self let request_ids: Vec<String> =
.to_share_with_set self.to_share_with_set.iter().map(|e| e.key().to_string()).collect();
.iter()
.map(|e| e.key().to_string())
.collect();
error!( error!(
all_request_ids = ?request_ids, all_request_ids = ?request_ids,
@ -290,11 +281,7 @@ impl OutboundGroupSession {
let relates_to: Option<Relation> = json_content let relates_to: Option<Relation> = json_content
.get("content") .get("content")
.map(|c| { .map(|c| c.get("m.relates_to").cloned().map(|r| serde_json::from_value(r).ok()))
c.get("m.relates_to")
.cloned()
.map(|r| serde_json::from_value(r).ok())
})
.flatten() .flatten()
.flatten(); .flatten();
@ -437,10 +424,7 @@ impl OutboundGroupSession {
/// Get the list of requests that need to be sent out for this session to be /// Get the list of requests that need to be sent out for this session to be
/// marked as shared. /// marked as shared.
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> { pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
self.to_share_with_set self.to_share_with_set.iter().map(|i| i.value().0.clone()).collect()
.iter()
.map(|i| i.value().0.clone())
.collect()
} }
/// Get the list of request ids this session is waiting for to be sent out. /// Get the list of request ids this session is waiting for to be sent out.
@ -455,11 +439,11 @@ impl OutboundGroupSession {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `device_id` - The device id of the device that created this session. Put differently, our /// * `device_id` - The device id of the device that created this session.
/// own device id. /// Put differently, our own device id.
/// ///
/// * `identity_keys` - The identity keys of the device that created this session, our own /// * `identity_keys` - The identity keys of the device that created this
/// identity keys. /// session, our own identity keys.
/// ///
/// * `pickle` - The pickled version of the `OutboundGroupSession`. /// * `pickle` - The pickled version of the `OutboundGroupSession`.
/// ///
@ -501,7 +485,8 @@ impl OutboundGroupSession {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `pickle_mode` - The mode that should be used to pickle the group session, /// * `pickle_mode` - The mode that should be used to pickle the group
/// session,
/// either an unencrypted mode or an encrypted using passphrase. /// either an unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickling_mode: PicklingMode) -> PickledOutboundGroupSession { pub async fn pickle(&self, pickling_mode: PicklingMode) -> PickledOutboundGroupSession {
let pickle: OutboundGroupSessionPickle = let pickle: OutboundGroupSessionPickle =
@ -522,10 +507,7 @@ impl OutboundGroupSession {
( (
u.key().clone(), u.key().clone(),
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
u.value() u.value().iter().map(|d| (d.key().clone(), *d.value())).collect(),
.iter()
.map(|d| (d.key().clone(), *d.value()))
.collect(),
) )
}) })
.collect(), .collect(),
@ -572,10 +554,7 @@ pub struct PickledOutboundGroupSession {
/// The room id this session is used for. /// The room id this session is used for.
pub room_id: Arc<RoomId>, pub room_id: Arc<RoomId>,
/// The timestamp when this session was created. /// The timestamp when this session was created.
#[serde( #[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
pub creation_time: Instant, pub creation_time: Instant,
/// The number of messages this session has already encrypted. /// The number of messages this session has already encrypted.
pub message_count: u64, pub message_count: u64,

View File

@ -91,21 +91,12 @@ pub(crate) mod test {
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await; bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob let one_time_key =
.one_time_keys() bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new()); let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
let sender_key = bob.identity_keys().curve25519().to_owned(); let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice let session =
.create_outbound_session_helper(&sender_key, &one_time_key) alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
.await
.unwrap();
(alice, session) (alice, session)
} }
@ -121,10 +112,7 @@ pub(crate) mod test {
assert_ne!(identity_keys.keys().len(), 0); assert_ne!(identity_keys.keys().len(), 0);
assert_ne!(identity_keys.iter().len(), 0); assert_ne!(identity_keys.iter().len(), 0);
assert!(identity_keys.contains_key("ed25519")); assert!(identity_keys.contains_key("ed25519"));
assert_eq!( assert_eq!(identity_keys.ed25519(), identity_keys.get("ed25519").unwrap());
identity_keys.ed25519(),
identity_keys.get("ed25519").unwrap()
);
assert!(!identity_keys.curve25519().is_empty()); assert!(!identity_keys.curve25519().is_empty());
account.mark_as_shared(); account.mark_as_shared();
@ -148,10 +136,7 @@ pub(crate) mod test {
assert_ne!(one_time_keys.iter().len(), 0); assert_ne!(one_time_keys.iter().len(), 0);
assert!(one_time_keys.contains_key("curve25519")); assert!(one_time_keys.contains_key("curve25519"));
assert_eq!(one_time_keys.curve25519().keys().len(), 10); assert_eq!(one_time_keys.curve25519().keys().len(), 10);
assert_eq!( assert_eq!(one_time_keys.curve25519(), one_time_keys.get("curve25519").unwrap());
one_time_keys.curve25519(),
one_time_keys.get("curve25519").unwrap()
);
account.mark_keys_as_published().await; account.mark_keys_as_published().await;
let one_time_keys = account.one_time_keys().await; let one_time_keys = account.one_time_keys().await;
@ -167,13 +152,7 @@ pub(crate) mod test {
let one_time_keys = alice.one_time_keys().await; let one_time_keys = alice.one_time_keys().await;
alice.mark_keys_as_published().await; alice.mark_keys_as_published().await;
let one_time_key = one_time_keys let one_time_key = one_time_keys.curve25519().iter().next().unwrap().1.to_owned();
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new()); let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
@ -197,10 +176,7 @@ pub(crate) mod test {
.await .await
.unwrap(); .unwrap();
assert!(alice_session assert!(alice_session.matches(bob_keys.curve25519(), prekey_message).await.unwrap());
.matches(bob_keys.curve25519(), prekey_message)
.await
.unwrap());
assert_eq!(bob_session.session_id(), alice_session.session_id()); assert_eq!(bob_session.session_id(), alice_session.session_id());
@ -213,10 +189,7 @@ pub(crate) mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = alice let (outbound, _) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
assert_eq!(0, outbound.message_index().await); assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared()); assert!(!outbound.shared());
@ -239,10 +212,7 @@ pub(crate) mod test {
let plaintext = "This is a secret to everybody".to_owned(); let plaintext = "This is a secret to everybody".to_owned();
let ciphertext = outbound.encrypt_helper(plaintext.clone()).await; let ciphertext = outbound.encrypt_helper(plaintext.clone()).await;
assert_eq!( assert_eq!(plaintext, inbound.decrypt_helper(ciphertext).await.unwrap().0);
plaintext,
inbound.decrypt_helper(ciphertext).await.unwrap().0
);
} }
#[tokio::test] #[tokio::test]
@ -250,10 +220,7 @@ pub(crate) mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (_, inbound) = alice let (_, inbound) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let export = inbound.export().await; let export = inbound.export().await;
let export: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); let export: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();

View File

@ -118,10 +118,8 @@ impl Session {
.get_key(DeviceKeyAlgorithm::Ed25519) .get_key(DeviceKeyAlgorithm::Ed25519)
.ok_or(EventError::MissingSigningKey)?; .ok_or(EventError::MissingSigningKey)?;
let relates_to = content let relates_to =
.get("m.relates_to") content.get("m.relates_to").cloned().and_then(|v| serde_json::from_value(v).ok());
.cloned()
.and_then(|v| serde_json::from_value(v).ok());
let payload = json!({ let payload = json!({
"sender": self.user_id.as_str(), "sender": self.user_id.as_str(),
@ -171,10 +169,7 @@ impl Session {
their_identity_key: &str, their_identity_key: &str,
message: PreKeyMessage, message: PreKeyMessage,
) -> Result<bool, OlmSessionError> { ) -> Result<bool, OlmSessionError> {
self.inner self.inner.lock().await.matches_inbound_session_from(their_identity_key, message)
.lock()
.await
.matches_inbound_session_from(their_identity_key, message)
} }
/// Returns the unique identifier for this session. /// Returns the unique identifier for this session.
@ -256,16 +251,10 @@ pub struct PickledSession {
/// The curve25519 key of the other user that we share this session with. /// The curve25519 key of the other user that we share this session with.
pub sender_key: String, pub sender_key: String,
/// The relative time elapsed since the session was created. /// The relative time elapsed since the session was created.
#[serde( #[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
pub creation_time: Instant, pub creation_time: Instant,
/// The relative time elapsed since the session was last used. /// The relative time elapsed since the session was last used.
#[serde( #[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
pub last_use_time: Instant, pub last_use_time: Instant,
} }

View File

@ -185,10 +185,7 @@ impl PrivateCrossSigningIdentity {
signed_keys signed_keys
.entry((&*self.user_id).to_owned()) .entry((&*self.user_id).to_owned())
.or_insert_with(BTreeMap::new) .or_insert_with(BTreeMap::new)
.insert( .insert(device_keys.device_id.to_string(), serde_json::to_value(device_keys)?);
device_keys.device_id.to_string(),
serde_json::to_value(device_keys)?,
);
Ok(SignatureUploadRequest::new(signed_keys)) Ok(SignatureUploadRequest::new(signed_keys))
} }
@ -228,10 +225,7 @@ impl PrivateCrossSigningIdentity {
signature, signature,
); );
let master = MasterSigning { let master = MasterSigning { inner: master, public_key: public_key.into() };
inner: master,
public_key: public_key.into(),
};
let identity = Self::new_helper(account.user_id(), master).await; let identity = Self::new_helper(account.user_id(), master).await;
let signature_request = identity let signature_request = identity
@ -249,20 +243,14 @@ impl PrivateCrossSigningIdentity {
let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning); let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning);
master.sign_subkey(&mut public_key).await; master.sign_subkey(&mut public_key).await;
let user = UserSigning { let user = UserSigning { inner: user, public_key: public_key.into() };
inner: user,
public_key: public_key.into(),
};
let self_signing = Signing::new(); let self_signing = Signing::new();
let mut public_key = let mut public_key =
self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning); self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning);
master.sign_subkey(&mut public_key).await; master.sign_subkey(&mut public_key).await;
let self_signing = SelfSigning { let self_signing = SelfSigning { inner: self_signing, public_key: public_key.into() };
inner: self_signing,
public_key: public_key.into(),
};
Self { Self {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_owned()),
@ -280,10 +268,7 @@ impl PrivateCrossSigningIdentity {
let master = Signing::new(); let master = Signing::new();
let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master); let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master);
let master = MasterSigning { let master = MasterSigning { inner: master, public_key: public_key.into() };
inner: master,
public_key: public_key.into(),
};
Self::new_helper(&user_id, master).await Self::new_helper(&user_id, master).await
} }
@ -333,11 +318,7 @@ impl PrivateCrossSigningIdentity {
None None
}; };
let pickle = PickledSignings { let pickle = PickledSignings { master_key, user_signing_key, self_signing_key };
master_key,
user_signing_key,
self_signing_key,
};
let pickle = serde_json::to_string(&pickle)?; let pickle = serde_json::to_string(&pickle)?;
@ -389,35 +370,16 @@ impl PrivateCrossSigningIdentity {
/// Get the upload request that is needed to share the public keys of this /// Get the upload request that is needed to share the public keys of this
/// identity. /// identity.
pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest { pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest {
let master_key = self let master_key =
.master_key self.master_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let user_signing_key = self let user_signing_key =
.user_signing_key self.user_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let self_signing_key = self let self_signing_key =
.self_signing_key self.self_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
UploadSigningKeysRequest { UploadSigningKeysRequest { master_key, self_signing_key, user_signing_key }
master_key,
self_signing_key,
user_signing_key,
}
} }
} }
@ -480,28 +442,12 @@ mod test {
assert!(master_key assert!(master_key
.public_key .public_key
.verify_subkey( .verify_subkey(&identity.self_signing_key.lock().await.as_ref().unwrap().public_key,)
&identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.is_ok()); .is_ok());
assert!(master_key assert!(master_key
.public_key .public_key
.verify_subkey( .verify_subkey(&identity.user_signing_key.lock().await.as_ref().unwrap().public_key,)
&identity
.user_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.is_ok()); .is_ok());
} }
@ -511,15 +457,11 @@ mod test {
let pickled = identity.pickle(pickle_key()).await.unwrap(); let pickled = identity.pickle(pickle_key()).await.unwrap();
let unpickled = PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key()) let unpickled =
.await PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key()).await.unwrap();
.unwrap();
assert_eq!(identity.user_id, unpickled.user_id); assert_eq!(identity.user_id, unpickled.user_id);
assert_eq!( assert_eq!(&*identity.master_key.lock().await, &*unpickled.master_key.lock().await);
&*identity.master_key.lock().await,
&*unpickled.master_key.lock().await
);
assert_eq!( assert_eq!(
&*identity.user_signing_key.lock().await, &*identity.user_signing_key.lock().await,
&*unpickled.user_signing_key.lock().await &*unpickled.user_signing_key.lock().await
@ -590,9 +532,6 @@ mod test {
bob_public.master_key = master.into(); bob_public.master_key = master.into();
user_signing user_signing.public_key.verify_master_key(bob_public.master_key()).unwrap();
.public_key
.verify_master_key(bob_public.master_key())
.unwrap();
} }
} }

View File

@ -68,9 +68,7 @@ pub struct Signing {
impl std::fmt::Debug for Signing { impl std::fmt::Debug for Signing {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Signing") f.debug_struct("Signing").field("public_key", &self.public_key.as_str()).finish()
.field("public_key", &self.public_key.as_str())
.finish()
} }
} }
@ -151,10 +149,7 @@ impl MasterSigning {
) -> Result<Self, SigningError> { ) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self { Ok(Self { inner, public_key: pickle.public_key.into() })
inner,
public_key: pickle.public_key.into(),
})
} }
pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) { pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
@ -195,10 +190,7 @@ impl UserSigning {
user: &UserIdentity, user: &UserIdentity,
) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> { ) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> {
let user_master: &CrossSigningKey = user.master_key().as_ref(); let user_master: &CrossSigningKey = user.master_key().as_ref();
let signature = self let signature = self.inner.sign_json(serde_json::to_value(user_master)?).await?;
.inner
.sign_json(serde_json::to_value(user_master)?)
.await?;
let mut signatures = BTreeMap::new(); let mut signatures = BTreeMap::new();
@ -223,10 +215,7 @@ impl UserSigning {
) -> Result<Self, SigningError> { ) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self { Ok(Self { inner, public_key: pickle.public_key.into() })
inner,
public_key: pickle.public_key.into(),
})
} }
} }
@ -274,10 +263,7 @@ impl SelfSigning {
) -> Result<Self, SigningError> { ) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self { Ok(Self { inner, public_key: pickle.public_key.into() })
inner,
public_key: pickle.public_key.into(),
})
} }
} }
@ -348,17 +334,12 @@ impl Signing {
getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object"); getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object");
let nonce = GenericArray::from_slice(nonce.as_slice()); let nonce = GenericArray::from_slice(nonce.as_slice());
let ciphertext = cipher let ciphertext =
.encrypt(nonce, self.seed.as_slice()) cipher.encrypt(nonce, self.seed.as_slice()).expect("Can't encrypt signing pickle");
.expect("Can't encrypt signing pickle");
let ciphertext = encode(ciphertext); let ciphertext = encode(ciphertext);
let pickle = InnerPickle { let pickle = InnerPickle { version: 1, nonce: encode(nonce.as_slice()), ciphertext };
version: 1,
nonce: encode(nonce.as_slice()),
ciphertext,
};
PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing")) PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing"))
} }
@ -371,10 +352,7 @@ impl Signing {
let mut keys = BTreeMap::new(); let mut keys = BTreeMap::new();
keys.insert( keys.insert(
DeviceKeyId::from_parts( DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.public_key().as_str().into())
DeviceKeyAlgorithm::Ed25519,
self.public_key().as_str().into(),
)
.to_string(), .to_string(),
self.public_key().to_string(), self.public_key().to_string(),
); );

View File

@ -29,9 +29,7 @@ pub(crate) struct Utility {
impl Utility { impl Utility {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self { inner: OlmUtility::new() }
inner: OlmUtility::new(),
}
} }
/// Verify a signed JSON object. /// Verify a signed JSON object.
@ -66,29 +64,20 @@ impl Utility {
let unsigned = json_object.remove("unsigned"); let unsigned = json_object.remove("unsigned");
let signatures = json_object.remove("signatures"); let signatures = json_object.remove("signatures");
let canonical_json: CanonicalJsonValue = json let canonical_json: CanonicalJsonValue =
.clone() json.clone().try_into().map_err(|_| SignatureError::NotAnObject)?;
.try_into()
.map_err(|_| SignatureError::NotAnObject)?;
let canonical_json: String = canonical_json.to_string(); let canonical_json: String = canonical_json.to_string();
let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?; let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?;
let signature_object = signatures let signature_object = signatures.as_object().ok_or(SignatureError::NoSignatureFound)?;
.as_object() let signature =
.ok_or(SignatureError::NoSignatureFound)?; signature_object.get(user_id.as_str()).ok_or(SignatureError::NoSignatureFound)?;
let signature = signature_object let signature =
.get(user_id.as_str()) signature.get(key_id.to_string()).ok_or(SignatureError::NoSignatureFound)?;
.ok_or(SignatureError::NoSignatureFound)?;
let signature = signature
.get(key_id.to_string())
.ok_or(SignatureError::NoSignatureFound)?;
let signature = signature.as_str().ok_or(SignatureError::NoSignatureFound)?; let signature = signature.as_str().ok_or(SignatureError::NoSignatureFound)?;
let ret = match self let ret = match self.inner.ed25519_verify(signing_key, &canonical_json, signature) {
.inner
.ed25519_verify(signing_key, &canonical_json, signature)
{
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => Err(SignatureError::VerificationError), Err(_) => Err(SignatureError::VerificationError),
}; };

View File

@ -108,11 +108,7 @@ pub struct KeysQueryRequest {
impl KeysQueryRequest { impl KeysQueryRequest {
pub(crate) fn new(device_keys: BTreeMap<UserId, Vec<DeviceIdBox>>) -> Self { pub(crate) fn new(device_keys: BTreeMap<UserId, Vec<DeviceIdBox>>) -> Self {
Self { Self { timeout: None, device_keys, token: None }
timeout: None,
device_keys,
token: None,
}
} }
} }
@ -176,19 +172,13 @@ impl From<SignatureUploadRequest> for OutgoingRequests {
impl From<OutgoingVerificationRequest> for OutgoingRequest { impl From<OutgoingVerificationRequest> for OutgoingRequest {
fn from(r: OutgoingVerificationRequest) -> Self { fn from(r: OutgoingVerificationRequest) -> Self {
Self { Self { request_id: r.request_id(), request: Arc::new(r.into()) }
request_id: r.request_id(),
request: Arc::new(r.into()),
}
} }
} }
impl From<SignatureUploadRequest> for OutgoingRequest { impl From<SignatureUploadRequest> for OutgoingRequest {
fn from(r: SignatureUploadRequest) -> Self { fn from(r: SignatureUploadRequest) -> Self {
Self { Self { request_id: Uuid::new_v4(), request: Arc::new(r.into()) }
request_id: Uuid::new_v4(),
request: Arc::new(r.into()),
}
} }
} }

View File

@ -104,10 +104,7 @@ impl GroupSessionCache {
room_id: &RoomId, room_id: &RoomId,
session_id: &str, session_id: &str,
) -> StoreResult<Option<OutboundGroupSession>> { ) -> StoreResult<Option<OutboundGroupSession>> {
Ok(self Ok(self.get_or_load(room_id).await?.filter(|o| session_id == o.session_id()))
.get_or_load(room_id)
.await?
.filter(|o| session_id == o.session_id()))
} }
} }
@ -126,11 +123,7 @@ impl GroupSessionManager {
const MAX_TO_DEVICE_MESSAGES: usize = 250; const MAX_TO_DEVICE_MESSAGES: usize = 250;
pub(crate) fn new(account: Account, store: Store) -> Self { pub(crate) fn new(account: Account, store: Store) -> Self {
Self { Self { account, store: store.clone(), sessions: GroupSessionCache::new(store) }
account,
store: store.clone(),
sessions: GroupSessionCache::new(store),
}
} }
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> { pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
@ -230,9 +223,7 @@ impl GroupSessionManager {
Ok((s, None)) Ok((s, None))
} }
} else { } else {
self.create_outbound_group_session(room_id, settings) self.create_outbound_group_session(room_id, settings).await.map(|(o, i)| (o, i.into()))
.await
.map(|(o, i)| (o, i.into()))
} }
} }
@ -252,10 +243,7 @@ impl GroupSessionManager {
let used_session = match encrypted { let used_session = match encrypted {
Ok((session, encrypted)) => { Ok((session, encrypted)) => {
message message.entry(device.user_id().clone()).or_insert_with(BTreeMap::new).insert(
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()), DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?, serde_json::value::to_raw_value(&encrypted)?,
); );
@ -270,10 +258,8 @@ impl GroupSessionManager {
Ok((used_session, message)) Ok((used_session, message))
}; };
let tasks: Vec<_> = devices let tasks: Vec<_> =
.iter() devices.iter().map(|d| spawn(encrypt(d.clone(), content.clone()))).collect();
.map(|d| spawn(encrypt(d.clone(), content.clone())))
.collect();
let results = join_all(tasks).await; let results = join_all(tasks).await;
@ -285,20 +271,14 @@ impl GroupSessionManager {
} }
for (user, device_messages) in message.into_iter() { for (user, device_messages) in message.into_iter() {
messages messages.entry(user).or_insert_with(BTreeMap::new).extend(device_messages);
.entry(user)
.or_insert_with(BTreeMap::new)
.extend(device_messages);
} }
} }
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let request = ToDeviceRequest { let request =
event_type: EventType::RoomEncrypted, ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages };
txn_id: id,
messages,
};
trace!( trace!(
recipient_count = request.message_count(), recipient_count = request.message_count(),
@ -330,20 +310,14 @@ impl GroupSessionManager {
"Calculating group session recipients" "Calculating group session recipients"
); );
let users_shared_with: HashSet<UserId> = outbound let users_shared_with: HashSet<UserId> =
.shared_with_set outbound.shared_with_set.iter().map(|k| k.key().clone()).collect();
.iter()
.map(|k| k.key().clone())
.collect();
let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect(); let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect();
// A user left if a user is missing from the set of users that should // A user left if a user is missing from the set of users that should
// get the session but is in the set of users that received the session. // get the session but is in the set of users that received the session.
let user_left = !users_shared_with let user_left = !users_shared_with.difference(&users).collect::<HashSet<_>>().is_empty();
.difference(&users)
.collect::<HashSet<_>>()
.is_empty();
let visibility_changed = outbound.settings().history_visibility != history_visibility; let visibility_changed = outbound.settings().history_visibility != history_visibility;
@ -358,10 +332,8 @@ impl GroupSessionManager {
for user_id in users { for user_id in users {
let user_devices = self.store.get_user_devices(&user_id).await?; let user_devices = self.store.get_user_devices(&user_id).await?;
let non_blacklisted_devices: Vec<Device> = user_devices let non_blacklisted_devices: Vec<Device> =
.devices() user_devices.devices().filter(|d| !d.is_blacklisted()).collect();
.filter(|d| !d.is_blacklisted())
.collect();
// If we haven't already concluded that the session should be // If we haven't already concluded that the session should be
// rotated for other reasons, we also need to check whether any // rotated for other reasons, we also need to check whether any
@ -369,10 +341,8 @@ impl GroupSessionManager {
// meantime. If so, we should also rotate the session. // meantime. If so, we should also rotate the session.
if !should_rotate { if !should_rotate {
// Device IDs that should receive this session // Device IDs that should receive this session
let non_blacklisted_device_ids: HashSet<&DeviceId> = non_blacklisted_devices let non_blacklisted_device_ids: HashSet<&DeviceId> =
.iter() non_blacklisted_devices.iter().map(|d| d.device_id()).collect();
.map(|d| d.device_id())
.collect();
if let Some(shared) = outbound.shared_with_set.get(user_id) { if let Some(shared) = outbound.shared_with_set.get(user_id) {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
@ -388,9 +358,8 @@ impl GroupSessionManager {
// //
// represents newly deleted or blacklisted devices. If this // represents newly deleted or blacklisted devices. If this
// set is non-empty, we must rotate. // set is non-empty, we must rotate.
let newly_deleted_or_blacklisted = shared let newly_deleted_or_blacklisted =
.difference(&non_blacklisted_device_ids) shared.difference(&non_blacklisted_device_ids).collect::<HashSet<_>>();
.collect::<HashSet<_>>();
if !newly_deleted_or_blacklisted.is_empty() { if !newly_deleted_or_blacklisted.is_empty() {
should_rotate = true; should_rotate = true;
@ -398,10 +367,7 @@ impl GroupSessionManager {
}; };
} }
devices devices.entry(user_id.clone()).or_insert_with(Vec::new).extend(non_blacklisted_devices);
.entry(user_id.clone())
.or_insert_with(Vec::new)
.extend(non_blacklisted_devices);
} }
debug!( debug!(
@ -461,25 +427,22 @@ impl GroupSessionManager {
let history_visibility = encryption_settings.history_visibility.clone(); let history_visibility = encryption_settings.history_visibility.clone();
let mut changes = Changes::default(); let mut changes = Changes::default();
let (outbound, inbound) = self let (outbound, inbound) =
.get_or_create_outbound_session(room_id, encryption_settings.clone()) self.get_or_create_outbound_session(room_id, encryption_settings.clone()).await?;
.await?;
if let Some(inbound) = inbound { if let Some(inbound) = inbound {
changes.outbound_group_sessions.push(outbound.clone()); changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound); changes.inbound_group_sessions.push(inbound);
} }
let (should_rotate, devices) = self let (should_rotate, devices) =
.collect_session_recipients(users, history_visibility, &outbound) self.collect_session_recipients(users, history_visibility, &outbound).await?;
.await?;
let outbound = if should_rotate { let outbound = if should_rotate {
let old_session_id = outbound.session_id(); let old_session_id = outbound.session_id();
let (outbound, inbound) = self let (outbound, inbound) =
.create_outbound_group_session(room_id, encryption_settings) self.create_outbound_group_session(room_id, encryption_settings).await?;
.await?;
changes.outbound_group_sessions.push(outbound.clone()); changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound); changes.inbound_group_sessions.push(inbound);
@ -514,9 +477,7 @@ impl GroupSessionManager {
if !devices.is_empty() { if !devices.is_empty() {
let users = devices.iter().fold(BTreeMap::new(), |mut acc, d| { let users = devices.iter().fold(BTreeMap::new(), |mut acc, d| {
acc.entry(d.user_id()) acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
.or_insert_with(BTreeSet::new)
.insert(d.device_id());
acc acc
}); });
@ -625,14 +586,8 @@ mod test {
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
machine machine.mark_request_as_sent(&uuid, &keys_query).await.unwrap();
.mark_request_as_sent(&uuid, &keys_query) machine.mark_request_as_sent(&uuid, &keys_claim).await.unwrap();
.await
.unwrap();
machine
.mark_request_as_sent(&uuid, &keys_claim)
.await
.unwrap();
machine machine
} }
@ -646,11 +601,7 @@ mod test {
let users: Vec<_> = keys_claim.one_time_keys.keys().collect(); let users: Vec<_> = keys_claim.one_time_keys.keys().collect();
let requests = machine let requests = machine
.share_group_session( .share_group_session(&room_id, users.clone().into_iter(), EncryptionSettings::default())
&room_id,
users.clone().into_iter(),
EncryptionSettings::default(),
)
.await .await
.unwrap(); .unwrap();

View File

@ -77,11 +77,7 @@ impl SessionManager {
} }
pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> { pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> {
if let Some(device) = self if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await? {
.store
.get_device_from_curve_key(sender, curve_key)
.await?
{
let sessions = device.get_sessions().await?; let sessions = device.get_sessions().await?;
if let Some(sessions) = sessions { if let Some(sessions) = sessions {
@ -120,22 +116,13 @@ impl SessionManager {
/// ///
/// If the device was wedged this will queue up a dummy to-device message. /// If the device was wedged this will queue up a dummy to-device message.
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> { async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
if self if self.wedged_devices.get(user_id).map(|d| d.remove(device_id)).flatten().is_some() {
.wedged_devices
.get(user_id)
.map(|d| d.remove(device_id))
.flatten()
.is_some()
{
if let Some(device) = self.store.get_device(user_id, device_id).await? { if let Some(device) = self.store.get_device(user_id, device_id).await? {
let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?; let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
.entry(device.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()), DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?, to_raw_value(&content)?,
); );
@ -347,9 +334,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &device_id); let account = ReadOnlyAccount::new(&user_id, &device_id);
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
store.save_account(account.clone()).await.unwrap(); store.save_account(account.clone()).await.unwrap();
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty( let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id.clone())));
user_id.clone(),
)));
let verification = let verification =
VerificationMachine::new(account.clone(), identity.clone(), store.clone()); VerificationMachine::new(account.clone(), identity.clone(), store.clone());
@ -358,10 +343,7 @@ mod test {
let store = Store::new(user_id.clone(), identity, store, verification); let store = Store::new(user_id.clone(), identity, store, verification);
let account = Account { let account = Account { inner: account, store: store.clone() };
inner: account,
store: store.clone(),
};
let session_cache = GroupSessionCache::new(store.clone()); let session_cache = GroupSessionCache::new(store.clone());
@ -405,10 +387,7 @@ mod test {
let response = KeyClaimResponse::new(one_time_keys); let response = KeyClaimResponse::new(one_time_keys);
manager manager.receive_keys_claim_response(&response).await.unwrap();
.receive_keys_claim_response(&response)
.await
.unwrap();
assert!(manager assert!(manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter()) .get_missing_sessions(&mut [bob.user_id().clone()].iter())
@ -434,11 +413,7 @@ mod test {
let bob_device = ReadOnlyDevice::from_account(&bob).await; let bob_device = ReadOnlyDevice::from_account(&bob).await;
session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601)); session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601));
manager manager.store.save_devices(&[bob_device.clone()]).await.unwrap();
.store
.save_devices(&[bob_device.clone()])
.await
.unwrap();
manager.store.save_sessions(&[session]).await.unwrap(); manager.store.save_sessions(&[session]).await.unwrap();
assert!(manager assert!(manager
@ -451,10 +426,7 @@ mod test {
assert!(!manager.users_for_key_claim.contains_key(bob.user_id())); assert!(!manager.users_for_key_claim.contains_key(bob.user_id()));
assert!(!manager.is_device_wedged(&bob_device)); assert!(!manager.is_device_wedged(&bob_device));
manager manager.mark_device_as_wedged(bob_device.user_id(), &curve_key).await.unwrap();
.mark_device_as_wedged(bob_device.user_id(), &curve_key)
.await
.unwrap();
assert!(manager.is_device_wedged(&bob_device)); assert!(manager.is_device_wedged(&bob_device));
assert!(manager.users_for_key_claim.contains_key(bob.user_id())); assert!(manager.users_for_key_claim.contains_key(bob.user_id()));
@ -480,10 +452,7 @@ mod test {
assert!(manager.outgoing_to_device_requests.is_empty()); assert!(manager.outgoing_to_device_requests.is_empty());
manager manager.receive_keys_claim_response(&response).await.unwrap();
.receive_keys_claim_response(&response)
.await
.unwrap();
assert!(!manager.is_device_wedged(&bob_device)); assert!(!manager.is_device_wedged(&bob_device));
assert!(manager assert!(manager

View File

@ -39,9 +39,7 @@ pub struct SessionStore {
impl SessionStore { impl SessionStore {
/// Create a new empty Session store. /// Create a new empty Session store.
pub fn new() -> Self { pub fn new() -> Self {
SessionStore { SessionStore { entries: Arc::new(DashMap::new()) }
entries: Arc::new(DashMap::new()),
}
} }
/// Add a session to the store. /// Add a session to the store.
@ -72,8 +70,7 @@ impl SessionStore {
/// Add a list of sessions belonging to the sender key. /// Add a list of sessions belonging to the sender key.
pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) { pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
self.entries self.entries.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
} }
} }
@ -87,9 +84,7 @@ pub struct GroupSessionStore {
impl GroupSessionStore { impl GroupSessionStore {
/// Create a new empty store. /// Create a new empty store.
pub fn new() -> Self { pub fn new() -> Self {
GroupSessionStore { GroupSessionStore { entries: Arc::new(DashMap::new()) }
entries: Arc::new(DashMap::new()),
}
} }
/// Add an inbound group session to the store. /// Add an inbound group session to the store.
@ -148,9 +143,7 @@ pub struct DeviceStore {
impl DeviceStore { impl DeviceStore {
/// Create a new empty device store. /// Create a new empty device store.
pub fn new() -> Self { pub fn new() -> Self {
DeviceStore { DeviceStore { entries: Arc::new(DashMap::new()) }
entries: Arc::new(DashMap::new()),
}
} }
/// Add a device to the store. /// Add a device to the store.
@ -167,19 +160,15 @@ impl DeviceStore {
/// Get the device with the given device_id and belonging to the given user. /// Get the device with the given device_id and belonging to the given user.
pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> { pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries self.entries.get(user_id).and_then(|m| m.get(device_id).map(|d| d.value().clone()))
.get(user_id)
.and_then(|m| m.get(device_id).map(|d| d.value().clone()))
} }
/// Remove the device with the given device_id and belonging to the given user. /// Remove the device with the given device_id and belonging to the given
/// user.
/// ///
/// Returns the device if it was removed, None if it wasn't in the store. /// Returns the device if it was removed, None if it wasn't in the store.
pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> { pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries self.entries.get(user_id).and_then(|m| m.remove(device_id)).map(|(_, d)| d)
.get(user_id)
.and_then(|m| m.remove(device_id))
.map(|(_, d)| d)
} }
/// Get a read-only view over all devices of the given user. /// Get a read-only view over all devices of the given user.
@ -240,10 +229,8 @@ mod test {
let (account, _) = get_account_and_session().await; let (account, _) = get_account_and_session().await;
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = account let (outbound, _) =
.create_group_session_pair_with_defaults(&room_id) account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
.await
.unwrap();
assert_eq!(0, outbound.message_index().await); assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared()); assert!(!outbound.shared());
@ -262,9 +249,7 @@ mod test {
let store = GroupSessionStore::new(); let store = GroupSessionStore::new();
store.add(inbound.clone()); store.add(inbound.clone());
let loaded_session = store let loaded_session = store.get(&room_id, "test_key", outbound.session_id()).unwrap();
.get(&room_id, "test_key", outbound.session_id())
.unwrap();
assert_eq!(inbound, loaded_session); assert_eq!(inbound, loaded_session);
} }

View File

@ -37,10 +37,7 @@ use crate::{
}; };
fn encode_key_info(info: &RequestedKeyInfo) -> String { fn encode_key_info(info: &RequestedKeyInfo) -> String {
format!( format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id)
"{}{}{}{}",
info.room_id, info.sender_key, info.algorithm, info.session_id
)
} }
/// An in-memory only store that will forget all the E2EE key once it's dropped. /// An in-memory only store that will forget all the E2EE key once it's dropped.
@ -121,22 +118,14 @@ impl CryptoStore for MemoryStore {
async fn save_changes(&self, mut changes: Changes) -> Result<()> { async fn save_changes(&self, mut changes: Changes) -> Result<()> {
self.save_sessions(changes.sessions).await; self.save_sessions(changes.sessions).await;
self.save_inbound_group_sessions(changes.inbound_group_sessions) self.save_inbound_group_sessions(changes.inbound_group_sessions).await;
.await;
self.save_devices(changes.devices.new).await; self.save_devices(changes.devices.new).await;
self.save_devices(changes.devices.changed).await; self.save_devices(changes.devices.changed).await;
self.delete_devices(changes.devices.deleted).await; self.delete_devices(changes.devices.deleted).await;
for identity in changes for identity in changes.identities.new.drain(..).chain(changes.identities.changed) {
.identities let _ = self.identities.insert(identity.user_id().to_owned(), identity.clone());
.new
.drain(..)
.chain(changes.identities.changed)
{
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
} }
for hash in changes.message_hashes { for hash in changes.message_hashes {
@ -167,9 +156,7 @@ impl CryptoStore for MemoryStore {
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<InboundGroupSession>> { ) -> Result<Option<InboundGroupSession>> {
Ok(self Ok(self.inbound_group_sessions.get(room_id, sender_key, session_id))
.inbound_group_sessions
.get(room_id, sender_key, session_id))
} }
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> { async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
@ -250,10 +237,7 @@ impl CryptoStore for MemoryStore {
&self, &self,
request_id: Uuid, request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
Ok(self Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone()))
.outgoing_key_requests
.get(&request_id)
.map(|r| r.clone()))
} }
async fn get_key_request_by_info( async fn get_key_request_by_info(
@ -278,9 +262,7 @@ impl CryptoStore for MemoryStore {
} }
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
self.outgoing_key_requests self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| {
.remove(&request_id)
.and_then(|(_, i)| {
let key_info_string = encode_key_info(&i.info); let key_info_string = encode_key_info(&i.info);
self.key_requests_by_info.remove(&key_info_string) self.key_requests_by_info.remove(&key_info_string)
}); });
@ -309,11 +291,7 @@ mod test {
store.save_sessions(vec![session.clone()]).await; store.save_sessions(vec![session.clone()]).await;
let sessions = store let sessions = store.get_sessions(&session.sender_key).await.unwrap().unwrap();
.get_sessions(&session.sender_key)
.await
.unwrap()
.unwrap();
let sessions = sessions.lock().await; let sessions = sessions.lock().await;
let loaded_session = &sessions[0]; let loaded_session = &sessions[0];
@ -326,10 +304,8 @@ mod test {
let (account, _) = get_account_and_session().await; let (account, _) = get_account_and_session().await;
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = account let (outbound, _) =
.create_group_session_pair_with_defaults(&room_id) account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
.await
.unwrap();
let inbound = InboundGroupSession::new( let inbound = InboundGroupSession::new(
"test_key", "test_key",
"test_key", "test_key",
@ -340,9 +316,7 @@ mod test {
.unwrap(); .unwrap();
let store = MemoryStore::new(); let store = MemoryStore::new();
let _ = store let _ = store.save_inbound_group_sessions(vec![inbound.clone()]).await;
.save_inbound_group_sessions(vec![inbound.clone()])
.await;
let loaded_session = store let loaded_session = store
.get_inbound_group_session(&room_id, "test_key", outbound.session_id()) .get_inbound_group_session(&room_id, "test_key", outbound.session_id())
@ -359,11 +333,8 @@ mod test {
store.save_devices(vec![device.clone()]).await; store.save_devices(vec![device.clone()]).await;
let loaded_device = store let loaded_device =
.get_device(device.user_id(), device.device_id()) store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(device, loaded_device); assert_eq!(device, loaded_device);
@ -377,11 +348,7 @@ mod test {
assert_eq!(&device, loaded_device); assert_eq!(&device, loaded_device);
store.delete_devices(vec![device.clone()]).await; store.delete_devices(vec![device.clone()]).await;
assert!(store assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
.get_device(device.user_id(), device.device_id())
.await
.unwrap()
.is_none());
} }
#[tokio::test] #[tokio::test]
@ -389,14 +356,8 @@ mod test {
let device = get_device(); let device = get_device();
let store = MemoryStore::new(); let store = MemoryStore::new();
assert!(store assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
.update_tracked_user(device.user_id(), false) assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(store.is_user_tracked(device.user_id())); assert!(store.is_user_tracked(device.user_id()));
} }
@ -405,10 +366,8 @@ mod test {
async fn test_message_hash() { async fn test_message_hash() {
let store = MemoryStore::new(); let store = MemoryStore::new();
let hash = OlmMessageHash { let hash =
sender_key: "test_sender".to_owned(), OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
hash: "test_hash".to_owned(),
};
let mut changes = Changes::default(); let mut changes = Changes::default();
changes.message_hashes.push(hash.clone()); changes.message_hashes.push(hash.clone());

View File

@ -143,12 +143,7 @@ impl Store {
store: Arc<Box<dyn CryptoStore>>, store: Arc<Box<dyn CryptoStore>>,
verification_machine: VerificationMachine, verification_machine: VerificationMachine,
) -> Self { ) -> Self {
Self { Self { user_id, identity, inner: store, verification_machine }
user_id,
identity,
inner: store,
verification_machine,
}
} }
pub async fn get_readonly_device( pub async fn get_readonly_device(
@ -160,10 +155,7 @@ impl Store {
} }
pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let changes = Changes { let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await self.save_changes(changes).await
} }
@ -171,10 +163,7 @@ impl Store {
#[cfg(test)] #[cfg(test)]
pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
changed: devices.to_vec(),
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -186,10 +175,7 @@ impl Store {
&self, &self,
sessions: &[InboundGroupSession], sessions: &[InboundGroupSession],
) -> Result<()> { ) -> Result<()> {
let changes = Changes { let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
inbound_group_sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await self.save_changes(changes).await
} }
@ -208,8 +194,7 @@ impl Store {
) -> Result<Option<Device>> { ) -> Result<Option<Device>> {
self.get_user_devices(user_id).await.map(|d| { self.get_user_devices(user_id).await.map(|d| {
d.devices().find(|d| { d.devices().find(|d| {
d.get_key(DeviceKeyAlgorithm::Curve25519) d.get_key(DeviceKeyAlgorithm::Curve25519).map_or(false, |k| k == curve_key)
.map_or(false, |k| k == curve_key)
}) })
}) })
} }
@ -217,12 +202,8 @@ impl Store {
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> { pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
let devices = self.inner.get_user_devices(user_id).await?; let devices = self.inner.get_user_devices(user_id).await?;
let own_identity = self let own_identity =
.inner self.inner.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
.get_user_identity(&self.user_id)
.await?
.map(|i| i.own().cloned())
.flatten();
let device_owner_identity = self.inner.get_user_identity(user_id).await.ok().flatten(); let device_owner_identity = self.inner.get_user_identity(user_id).await.ok().flatten();
Ok(UserDevices { Ok(UserDevices {
@ -239,18 +220,11 @@ impl Store {
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<Option<Device>> { ) -> Result<Option<Device>> {
let own_identity = self let own_identity =
.get_user_identity(&self.user_id) self.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
.await?
.map(|i| i.own().cloned())
.flatten();
let device_owner_identity = self.get_user_identity(user_id).await?; let device_owner_identity = self.get_user_identity(user_id).await?;
Ok(self Ok(self.inner.get_device(user_id, device_id).await?.map(|d| Device {
.inner
.get_device(user_id, device_id)
.await?
.map(|d| Device {
inner: d, inner: d,
private_identity: self.identity.clone(), private_identity: self.identity.clone(),
verification_machine: self.verification_machine.clone(), verification_machine: self.verification_machine.clone(),
@ -364,7 +338,8 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Get all the inbound group sessions we have stored. /// Get all the inbound group sessions we have stored.
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>; async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>;
/// Get the outobund group sessions we have stored that is used for the given room. /// Get the outobund group sessions we have stored that is used for the
/// given room.
async fn get_outbound_group_sessions( async fn get_outbound_group_sessions(
&self, &self,
room_id: &RoomId, room_id: &RoomId,

View File

@ -113,9 +113,7 @@ impl PickleKey {
/// Get a `PicklingMode` version of this pickle key. /// Get a `PicklingMode` version of this pickle key.
pub fn pickle_mode(&self) -> PicklingMode { pub fn pickle_mode(&self) -> PicklingMode {
PicklingMode::Encrypted { PicklingMode::Encrypted { key: self.aes256_key.clone() }
key: self.aes256_key.clone(),
}
} }
/// Get the raw AES256 key. /// Get the raw AES256 key.
@ -141,10 +139,7 @@ impl PickleKey {
getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key"); getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key");
let ciphertext = cipher let ciphertext = cipher
.encrypt( .encrypt(&GenericArray::from_slice(nonce.as_ref()), self.aes256_key.as_slice())
&GenericArray::from_slice(nonce.as_ref()),
self.aes256_key.as_slice(),
)
.expect("Can't encrypt pickle key"); .expect("Can't encrypt pickle key");
EncryptedPickleKey { EncryptedPickleKey {
@ -180,9 +175,7 @@ impl PickleKey {
} }
}; };
Ok(Self { Ok(Self { aes256_key: decrypted })
aes256_key: decrypted,
})
} }
} }

View File

@ -96,13 +96,7 @@ impl EncodeKey for &str {
impl EncodeKey for (&str, &str) { impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[ [self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
self.0.as_bytes(),
&[Self::SEPARATOR],
self.1.as_bytes(),
&[Self::SEPARATOR],
]
.concat()
} }
} }
@ -163,9 +157,7 @@ impl std::fmt::Debug for SledStore {
if let Some(path) = &self.path { if let Some(path) = &self.path {
f.debug_struct("SledStore").field("path", &path).finish() f.debug_struct("SledStore").field("path", &path).finish()
} else { } else {
f.debug_struct("SledStore") f.debug_struct("SledStore").field("path", &"memory store").finish()
.field("path", &"memory store")
.finish()
} }
} }
} }
@ -252,9 +244,8 @@ impl SledStore {
} }
fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> { fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> {
let key = if let Some(key) = database let key = if let Some(key) =
.get("pickle_key".encode())? database.get("pickle_key".encode())?.map(|v| serde_json::from_slice(&v))
.map(|v| serde_json::from_slice(&v))
{ {
PickleKey::from_encrypted(passphrase, key?) PickleKey::from_encrypted(passphrase, key?)
.map_err(|_| CryptoStoreError::UnpicklingError)? .map_err(|_| CryptoStoreError::UnpicklingError)?
@ -296,9 +287,7 @@ impl SledStore {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>> { ) -> Result<Option<OutboundGroupSession>> {
let account_info = self let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
self.outbound_group_sessions self.outbound_group_sessions
.get(room_id.encode())? .get(room_id.encode())?
@ -500,17 +489,11 @@ impl SledStore {
&self, &self,
id: &[u8], id: &[u8],
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
let request = self let request =
.outgoing_key_requests self.outgoing_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?;
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?;
let request = if request.is_none() { let request = if request.is_none() {
self.unsent_key_requests self.unsent_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?
} else { } else {
request request
}; };
@ -552,10 +535,7 @@ impl CryptoStore for SledStore {
*self.account_info.write().unwrap() = Some(account_info); *self.account_info.write().unwrap() = Some(account_info);
let changes = Changes { let changes = Changes { account: Some(account), ..Default::default() };
account: Some(account),
..Default::default()
};
self.save_changes(changes).await self.save_changes(changes).await
} }
@ -578,9 +558,7 @@ impl CryptoStore for SledStore {
} }
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> { async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
let account_info = self let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
if self.session_cache.get(sender_key).is_none() { if self.session_cache.get(sender_key).is_none() {
let sessions: Result<Vec<Session>> = self let sessions: Result<Vec<Session>> = self
@ -612,16 +590,10 @@ impl CryptoStore for SledStore {
session_id: &str, session_id: &str,
) -> Result<Option<InboundGroupSession>> { ) -> Result<Option<InboundGroupSession>> {
let key = (room_id.as_str(), sender_key, session_id).encode(); let key = (room_id.as_str(), sender_key, session_id).encode();
let pickle = self let pickle = self.inbound_group_sessions.get(&key)?.map(|p| serde_json::from_slice(&p));
.inbound_group_sessions
.get(&key)?
.map(|p| serde_json::from_slice(&p));
if let Some(pickle) = pickle { if let Some(pickle) = pickle {
Ok(Some(InboundGroupSession::from_pickle( Ok(Some(InboundGroupSession::from_pickle(pickle?, self.get_pickle_mode())?))
pickle?,
self.get_pickle_mode(),
)?))
} else { } else {
Ok(None) Ok(None)
} }
@ -657,10 +629,7 @@ impl CryptoStore for SledStore {
fn users_for_key_query(&self) -> HashSet<UserId> { fn users_for_key_query(&self) -> HashSet<UserId> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
self.users_for_key_query_cache self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()
.iter()
.map(|u| u.clone())
.collect()
} }
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> { async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
@ -714,9 +683,7 @@ impl CryptoStore for SledStore {
} }
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> { async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
Ok(self Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?)
.olm_hashes
.contains_key(serde_json::to_vec(message_hash)?)?)
} }
async fn get_outgoing_key_request( async fn get_outgoing_key_request(
@ -752,11 +719,8 @@ impl CryptoStore for SledStore {
} }
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
let ret: Result<(), TransactionError<serde_json::Error>> = ( let ret: Result<(), TransactionError<serde_json::Error>> =
&self.outgoing_key_requests, (&self.outgoing_key_requests, &self.unsent_key_requests, &self.key_requests_by_info)
&self.unsent_key_requests,
&self.key_requests_by_info,
)
.transaction( .transaction(
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| { |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
@ -846,10 +810,7 @@ mod test {
async fn get_loaded_store() -> (ReadOnlyAccount, SledStore, tempfile::TempDir) { async fn get_loaded_store() -> (ReadOnlyAccount, SledStore, tempfile::TempDir) {
let (store, dir) = get_store(None).await; let (store, dir) = get_store(None).await;
let account = get_account(); let account = get_account();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
(account, store, dir) (account, store, dir)
} }
@ -863,21 +824,12 @@ mod test {
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await; bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob let one_time_key =
.one_time_keys() bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new()); let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
let sender_key = bob.identity_keys().curve25519().to_owned(); let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice let session =
.create_outbound_session_helper(&sender_key, &one_time_key) alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
.await
.unwrap();
(alice, session) (alice, session)
} }
@ -895,10 +847,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none()); assert!(store.load_account().await.unwrap().is_none());
let account = get_account(); let account = get_account();
store store.save_account(account).await.expect("Can't save account");
.save_account(account)
.await
.expect("Can't save account");
} }
#[async_test] #[async_test]
@ -906,10 +855,7 @@ mod test {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let account = get_account(); let account = get_account();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap(); let loaded_account = loaded_account.unwrap();
@ -922,10 +868,7 @@ mod test {
let (store, _dir) = get_store(Some("secret_passphrase")).await; let (store, _dir) = get_store(Some("secret_passphrase")).await;
let account = get_account(); let account = get_account();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap(); let loaded_account = loaded_account.unwrap();
@ -938,50 +881,32 @@ mod test {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let account = get_account(); let account = get_account();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
account.mark_as_shared(); account.mark_as_shared();
account.update_uploaded_key_count(50); account.update_uploaded_key_count(50);
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap(); let loaded_account = loaded_account.unwrap();
assert_eq!(account, loaded_account); assert_eq!(account, loaded_account);
assert_eq!( assert_eq!(account.uploaded_key_count(), loaded_account.uploaded_key_count());
account.uploaded_key_count(),
loaded_account.uploaded_key_count()
);
} }
#[async_test] #[async_test]
async fn load_sessions() { async fn load_sessions() {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let changes = Changes { let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
sessions: vec![session.clone()],
..Default::default()
};
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let sessions = store let sessions =
.get_sessions(&session.sender_key) store.get_sessions(&session.sender_key).await.expect("Can't load sessions").unwrap();
.await
.expect("Can't load sessions")
.unwrap();
let loaded_session = sessions.lock().await.get(0).cloned().unwrap(); let loaded_session = sessions.lock().await.get(0).cloned().unwrap();
assert_eq!(&session, &loaded_session); assert_eq!(&session, &loaded_session);
@ -994,15 +919,9 @@ mod test {
let sender_key = session.sender_key.to_owned(); let sender_key = session.sender_key.to_owned();
let session_id = session.session_id().to_owned(); let session_id = session.session_id().to_owned();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let changes = Changes { let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
sessions: vec![session.clone()],
..Default::default()
};
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
@ -1040,15 +959,9 @@ mod test {
) )
.expect("Can't create session"); .expect("Can't create session");
let changes = Changes { let changes = Changes { inbound_group_sessions: vec![session], ..Default::default() };
inbound_group_sessions: vec![session],
..Default::default()
};
store store.save_changes(changes).await.expect("Can't save group session");
.save_changes(changes)
.await
.expect("Can't save group session");
} }
#[async_test] #[async_test]
@ -1072,15 +985,10 @@ mod test {
let session = InboundGroupSession::from_export(export).unwrap(); let session = InboundGroupSession::from_export(export).unwrap();
let changes = Changes { let changes =
inbound_group_sessions: vec![session.clone()], Changes { inbound_group_sessions: vec![session.clone()], ..Default::default() };
..Default::default()
};
store store.save_changes(changes).await.expect("Can't save group session");
.save_changes(changes)
.await
.expect("Can't save group session");
drop(store); drop(store);
@ -1103,21 +1011,12 @@ mod test {
let (_account, store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
assert!(store assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
.update_tracked_user(device.user_id(), false) assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(store.is_user_tracked(device.user_id())); assert!(store.is_user_tracked(device.user_id()));
assert!(!store.users_for_key_query().contains(device.user_id())); assert!(!store.users_for_key_query().contains(device.user_id()));
assert!(!store assert!(!store.update_tracked_user(device.user_id(), true).await.unwrap());
.update_tracked_user(device.user_id(), true)
.await
.unwrap());
assert!(store.users_for_key_query().contains(device.user_id())); assert!(store.users_for_key_query().contains(device.user_id()));
drop(store); drop(store);
@ -1128,10 +1027,7 @@ mod test {
assert!(store.is_user_tracked(device.user_id())); assert!(store.is_user_tracked(device.user_id()));
assert!(store.users_for_key_query().contains(device.user_id())); assert!(store.users_for_key_query().contains(device.user_id()));
store store.update_tracked_user(device.user_id(), false).await.unwrap();
.update_tracked_user(device.user_id(), false)
.await
.unwrap();
assert!(!store.users_for_key_query().contains(device.user_id())); assert!(!store.users_for_key_query().contains(device.user_id()));
drop(store); drop(store);
@ -1148,10 +1044,7 @@ mod test {
let device = get_device(); let device = get_device();
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
changed: vec![device.clone()],
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -1163,11 +1056,8 @@ mod test {
store.load_account().await.unwrap(); store.load_account().await.unwrap();
let loaded_device = store let loaded_device =
.get_device(device.user_id(), device.device_id()) store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(device, loaded_device); assert_eq!(device, loaded_device);
@ -1188,20 +1078,14 @@ mod test {
let device = get_device(); let device = get_device();
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
changed: vec![device.clone()],
..Default::default()
},
..Default::default() ..Default::default()
}; };
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { deleted: vec![device.clone()], ..Default::default() },
deleted: vec![device.clone()],
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -1212,10 +1096,7 @@ mod test {
store.load_account().await.unwrap(); store.load_account().await.unwrap();
let loaded_device = store let loaded_device = store.get_device(device.user_id(), device.device_id()).await.unwrap();
.get_device(device.user_id(), device.device_id())
.await
.unwrap();
assert!(loaded_device.is_none()); assert!(loaded_device.is_none());
} }
@ -1232,10 +1113,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &device_id); let account = ReadOnlyAccount::new(&user_id, &device_id);
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let own_identity = get_own_identity(); let own_identity = get_own_identity();
@ -1247,10 +1125,7 @@ mod test {
..Default::default() ..Default::default()
}; };
store store.save_changes(changes).await.expect("Can't save identity");
.save_changes(changes)
.await
.expect("Can't save identity");
drop(store); drop(store);
@ -1258,17 +1133,10 @@ mod test {
store.load_account().await.unwrap(); store.load_account().await.unwrap();
let loaded_user = store let loaded_user = store.get_user_identity(own_identity.user_id()).await.unwrap().unwrap();
.get_user_identity(own_identity.user_id())
.await
.unwrap()
.unwrap();
assert_eq!(loaded_user.master_key(), own_identity.master_key()); assert_eq!(loaded_user.master_key(), own_identity.master_key());
assert_eq!( assert_eq!(loaded_user.self_signing_key(), own_identity.self_signing_key());
loaded_user.self_signing_key(),
own_identity.self_signing_key()
);
assert_eq!(loaded_user, own_identity.clone().into()); assert_eq!(loaded_user, own_identity.clone().into());
let other_identity = get_other_identity(); let other_identity = get_other_identity();
@ -1283,17 +1151,10 @@ mod test {
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let loaded_user = store let loaded_user = store.get_user_identity(other_identity.user_id()).await.unwrap().unwrap();
.get_user_identity(other_identity.user_id())
.await
.unwrap()
.unwrap();
assert_eq!(loaded_user.master_key(), other_identity.master_key()); assert_eq!(loaded_user.master_key(), other_identity.master_key());
assert_eq!( assert_eq!(loaded_user.self_signing_key(), other_identity.self_signing_key());
loaded_user.self_signing_key(),
other_identity.self_signing_key()
);
assert_eq!(loaded_user, other_identity.into()); assert_eq!(loaded_user, other_identity.into());
own_identity.mark_as_verified(); own_identity.mark_as_verified();
@ -1317,10 +1178,7 @@ mod test {
assert!(store.load_identity().await.unwrap().is_none()); assert!(store.load_identity().await.unwrap().is_none());
let identity = PrivateCrossSigningIdentity::new(alice_id()).await; let identity = PrivateCrossSigningIdentity::new(alice_id()).await;
let changes = Changes { let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
private_identity: Some(identity.clone()),
..Default::default()
};
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let loaded_identity = store.load_identity().await.unwrap().unwrap(); let loaded_identity = store.load_identity().await.unwrap().unwrap();
@ -1331,10 +1189,8 @@ mod test {
async fn olm_hash_saving() { async fn olm_hash_saving() {
let (_, store, _dir) = get_loaded_store().await; let (_, store, _dir) = get_loaded_store().await;
let hash = OlmMessageHash { let hash =
sender_key: "test_sender".to_owned(), OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
hash: "test_hash".to_owned(),
};
let mut changes = Changes::default(); let mut changes = Changes::default();
changes.message_hashes.push(hash.clone()); changes.message_hashes.push(hash.clone());

View File

@ -83,17 +83,13 @@ impl VerificationMachine {
); );
let request = match content.into() { let request = match content.into() {
OutgoingContent::Room(r, c) => RoomMessageRequest { OutgoingContent::Room(r, c) => {
room_id: r, RoomMessageRequest { room_id: r, txn_id: Uuid::new_v4(), content: c }.into()
txn_id: Uuid::new_v4(),
content: c,
} }
.into(),
OutgoingContent::ToDevice(c) => { OutgoingContent::ToDevice(c) => {
let request = content_to_request(device.user_id(), device.device_id(), c); let request = content_to_request(device.user_id(), device.device_id(), c);
self.verifications self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone());
.insert(sas.flow_id().as_str().to_owned(), sas.clone());
request.into() request.into()
} }
@ -134,10 +130,7 @@ impl VerificationMachine {
let request = content_to_request(recipient, recipient_device, c); let request = content_to_request(recipient, recipient_device, c);
let request_id = request.txn_id; let request_id = request.txn_id;
let request = OutgoingRequest { let request = OutgoingRequest { request_id, request: Arc::new(request.into()) };
request_id,
request: Arc::new(request.into()),
};
self.outgoing_to_device_messages.insert(request_id, request); self.outgoing_to_device_messages.insert(request_id, request);
} }
@ -147,12 +140,7 @@ impl VerificationMachine {
let request = OutgoingRequest { let request = OutgoingRequest {
request: Arc::new( request: Arc::new(
RoomMessageRequest { RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(),
room_id: r,
txn_id: request_id,
content: c,
}
.into(),
), ),
request_id, request_id,
}; };
@ -181,32 +169,22 @@ impl VerificationMachine {
} }
pub fn outgoing_room_message_requests(&self) -> Vec<OutgoingRequest> { pub fn outgoing_room_message_requests(&self) -> Vec<OutgoingRequest> {
self.outgoing_room_messages self.outgoing_room_messages.iter().map(|r| (*r).clone()).collect()
.iter()
.map(|r| (*r).clone())
.collect()
} }
pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> { pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
self.outgoing_to_device_messages self.outgoing_to_device_messages.iter().map(|r| (*r).clone()).collect()
.iter()
.map(|r| (*r).clone())
.collect()
} }
pub fn garbage_collect(&self) { pub fn garbage_collect(&self) {
self.verifications self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
.retain(|_, s| !(s.is_done() || s.is_canceled()));
for sas in self.verifications.iter() { for sas in self.verifications.iter() {
if let Some(r) = sas.cancel_if_timed_out() { if let Some(r) = sas.cancel_if_timed_out() {
self.outgoing_to_device_messages.insert( self.outgoing_to_device_messages.insert(
r.request_id(), r.request_id(),
OutgoingRequest { OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) },
request_id: r.request_id(),
request: Arc::new(r.into()),
},
); );
} }
} }
@ -266,10 +244,8 @@ impl VerificationMachine {
); );
if let Some((_, request)) = self.requests.remove(&e.content.relation.event_id) { if let Some((_, request)) = self.requests.remove(&e.content.relation.event_id) {
if let Some(d) = self if let Some(d) =
.store self.store.get_device(&e.sender, &e.content.from_device).await?
.get_device(&e.sender, &e.content.from_device)
.await?
{ {
match request.into_started_sas( match request.into_started_sas(
e, e,
@ -296,7 +272,8 @@ impl VerificationMachine {
"Can't start key verification with {} {}, canceling: {:?}", "Can't start key verification with {} {}, canceling: {:?}",
e.sender, e.content.from_device, c e.sender, e.content.from_device, c
); );
// self.queue_up_content(&e.sender, &e.content.from_device, c) // self.queue_up_content(&e.sender,
// &e.content.from_device, c)
} }
} }
} }
@ -375,11 +352,7 @@ impl VerificationMachine {
e.content.from_device e.content.from_device
); );
if let Some(d) = self if let Some(d) = self.store.get_device(&e.sender, &e.content.from_device).await? {
.store
.get_device(&e.sender, &e.content.from_device)
.await?
{
let private_identity = self.private_identity.lock().await.clone(); let private_identity = self.private_identity.lock().await.clone();
match Sas::from_start_event( match Sas::from_start_event(
self.account.clone(), self.account.clone(),
@ -390,8 +363,7 @@ impl VerificationMachine {
self.store.get_user_identity(&e.sender).await?, self.store.get_user_identity(&e.sender).await?,
) { ) {
Ok(s) => { Ok(s) => {
self.verifications self.verifications.insert(e.content.transaction_id.clone(), s);
.insert(e.content.transaction_id.clone(), s);
} }
Err(c) => { Err(c) => {
warn!( warn!(
@ -442,10 +414,7 @@ impl VerificationMachine {
self.outgoing_to_device_messages.insert( self.outgoing_to_device_messages.insert(
request_id, request_id,
OutgoingRequest { OutgoingRequest { request_id, request: Arc::new(r.into()) },
request_id,
request: Arc::new(r.into()),
},
); );
} }
} }
@ -521,10 +490,7 @@ mod test {
); );
machine machine
.receive_event(&wrap_any_to_device_content( .receive_event(&wrap_any_to_device_content(bob_sas.user_id(), start_content.into()))
bob_sas.user_id(),
start_content.into(),
))
.await .await
.unwrap(); .unwrap();
@ -559,11 +525,7 @@ mod test {
alice_machine.receive_event(&event).await.unwrap(); alice_machine.receive_event(&event).await.unwrap();
assert!(!alice_machine.outgoing_to_device_messages.is_empty()); assert!(!alice_machine.outgoing_to_device_messages.is_empty());
let request = alice_machine let request = alice_machine.outgoing_to_device_messages.iter().next().unwrap();
.outgoing_to_device_messages
.iter()
.next()
.unwrap();
let txn_id = *request.request_id(); let txn_id = *request.request_id();

View File

@ -56,11 +56,7 @@ pub(crate) mod test {
sender: &UserId, sender: &UserId,
content: OutgoingContent, content: OutgoingContent,
) -> AnyToDeviceEvent { ) -> AnyToDeviceEvent {
let content = if let OutgoingContent::ToDevice(c) = content { let content = if let OutgoingContent::ToDevice(c) = content { c } else { unreachable!() };
c
} else {
unreachable!()
};
match content { match content {
AnyToDeviceEventContent::KeyVerificationKey(c) => { AnyToDeviceEventContent::KeyVerificationKey(c) => {
@ -95,22 +91,11 @@ pub(crate) mod test {
pub(crate) fn get_content_from_request( pub(crate) fn get_content_from_request(
request: &OutgoingVerificationRequest, request: &OutgoingVerificationRequest,
) -> OutgoingContent { ) -> OutgoingContent {
let request = if let OutgoingVerificationRequest::ToDevice(r) = request { let request =
r if let OutgoingVerificationRequest::ToDevice(r) = request { r } else { unreachable!() };
} else {
unreachable!()
};
let json: Value = serde_json::from_str( let json: Value = serde_json::from_str(
request request.messages.values().next().unwrap().values().next().unwrap().get(),
.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
) )
.unwrap(); .unwrap();

View File

@ -106,15 +106,13 @@ impl VerificationRequest {
content: &KeyVerificationRequestEventContent, content: &KeyVerificationRequestEventContent,
) -> Self { ) -> Self {
Self { Self {
inner: Arc::new(Mutex::new(InnerRequest::Requested( inner: Arc::new(Mutex::new(InnerRequest::Requested(RequestState::from_request_event(
RequestState::from_request_event(
account.user_id(), account.user_id(),
account.device_id(), account.device_id(),
sender, sender,
event_id, event_id,
content, content,
), )))),
))),
account, account,
other_user_id: sender.clone().into(), other_user_id: sender.clone().into(),
private_cross_signing_identity, private_cross_signing_identity,
@ -285,10 +283,7 @@ impl RequestState<Created> {
own_user_id: self.own_user_id, own_user_id: self.own_user_id,
own_device_id: self.own_device_id, own_device_id: self.own_device_id,
other_user_id: self.other_user_id, other_user_id: self.other_user_id,
state: Sent { state: Sent { methods: SUPPORTED_METHODS.to_vec(), flow_id: response.event_id.clone() },
methods: SUPPORTED_METHODS.to_vec(),
flow_id: response.event_id.clone(),
},
} }
} }
} }
@ -368,9 +363,7 @@ impl RequestState<Requested> {
let content = ReadyEventContent { let content = ReadyEventContent {
from_device: self.own_device_id, from_device: self.own_device_id,
methods: self.state.methods, methods: self.state.methods,
relation: Relation { relation: Relation { event_id: self.state.flow_id },
event_id: self.state.flow_id,
},
}; };
(state, content) (state, content)
@ -580,9 +573,7 @@ mod test {
panic!("Invalid start event content type"); panic!("Invalid start event content type");
}; };
let alice_sas = alice_request let alice_sas = alice_request.into_started_sas(&event, bob_device, None).unwrap();
.into_started_sas(&event, bob_device, None)
.unwrap();
assert!(!bob_sas.is_canceled()); assert!(!bob_sas.is_canceled());
assert!(!alice_sas.is_canceled()); assert!(!alice_sas.is_canceled());

View File

@ -61,10 +61,7 @@ impl StartContent {
StartContent::Room(_, c) => serde_json::to_value(c), StartContent::Room(_, c) => serde_json::to_value(c),
}; };
content content.expect("Can't serialize content").try_into().expect("Can't canonicalize content")
.expect("Can't serialize content")
.try_into()
.expect("Can't canonicalize content")
} }
} }

View File

@ -62,12 +62,7 @@ pub fn calculate_commitment(public_key: &str, content: impl Into<StartContent>)
let content = content.into().canonical_json(); let content = content.into().canonical_json();
let content_string = content.to_string(); let content_string = content.to_string();
encode( encode(Sha256::new().chain(&public_key).chain(&content_string).finalize())
Sha256::new()
.chain(&public_key)
.chain(&content_string)
.finalize(),
)
} }
/// Get a tuple of an emoji and a description of the emoji using a number. /// Get a tuple of an emoji and a description of the emoji using a number.
@ -231,11 +226,7 @@ pub fn receive_mac_event(
.calculate_mac(key, &format!("{}{}", info, key_id)) .calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC") .expect("Can't calculate SAS MAC")
{ {
trace!( trace!("Successfully verified the device key {} from {}", key_id, sender);
"Successfully verified the device key {} from {}",
key_id,
sender
);
verified_devices.push(ids.other_device.clone()); verified_devices.push(ids.other_device.clone());
} else { } else {
@ -250,11 +241,7 @@ pub fn receive_mac_event(
.calculate_mac(key, &format!("{}{}", info, key_id)) .calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC") .expect("Can't calculate SAS MAC")
{ {
trace!( trace!("Successfully verified the master key {} from {}", key_id, sender);
"Successfully verified the master key {} from {}",
key_id,
sender
);
verified_identities.push(identity.clone()) verified_identities.push(identity.clone())
} else { } else {
return Err(CancelCode::KeyMismatch); return Err(CancelCode::KeyMismatch);
@ -316,8 +303,7 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
mac.insert( mac.insert(
key_id.to_string(), key_id.to_string(),
sas.calculate_mac(key, &format!("{}{}", info, key_id)) sas.calculate_mac(key, &format!("{}{}", info, key_id)).expect("Can't calculate SAS MAC"),
.expect("Can't calculate SAS MAC"),
); );
// TODO Add the cross signing master key here if we trust/have it. // TODO Add the cross signing master key here if we trust/have it.
@ -329,23 +315,13 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
.expect("Can't calculate SAS MAC"); .expect("Can't calculate SAS MAC");
match flow_id { match flow_id {
FlowId::ToDevice(s) => MacToDeviceEventContent { FlowId::ToDevice(s) => {
transaction_id: s.to_string(), MacToDeviceEventContent { transaction_id: s.to_string(), keys, mac }.into()
keys, }
mac, FlowId::InRoom(r, e) => {
(r.clone(), MacEventContent { mac, keys, relation: Relation { event_id: e.clone() } })
.into()
} }
.into(),
FlowId::InRoom(r, e) => (
r.clone(),
MacEventContent {
mac,
keys,
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
} }
} }
@ -366,24 +342,12 @@ fn extra_info_sas(
flow_id: &str, flow_id: &str,
we_started: bool, we_started: bool,
) -> String { ) -> String {
let our_info = format!( let our_info = format!("{}|{}|{}", ids.account.user_id(), ids.account.device_id(), own_pubkey);
"{}|{}|{}", let their_info =
ids.account.user_id(), format!("{}|{}|{}", ids.other_device.user_id(), ids.other_device.device_id(), their_pubkey);
ids.account.device_id(),
own_pubkey
);
let their_info = format!(
"{}|{}|{}",
ids.other_device.user_id(),
ids.other_device.device_id(),
their_pubkey
);
let (first_info, second_info) = if we_started { let (first_info, second_info) =
(our_info, their_info) if we_started { (our_info, their_info) } else { (their_info, our_info) };
} else {
(their_info, our_info)
};
let info = format!( let info = format!(
"MATRIX_KEY_VERIFICATION_SAS|{first_info}|{second_info}|{flow_id}", "MATRIX_KEY_VERIFICATION_SAS|{first_info}|{second_info}|{flow_id}",
@ -580,11 +544,7 @@ pub fn content_to_request(
_ => unreachable!(), _ => unreachable!(),
}; };
ToDeviceRequest { ToDeviceRequest { txn_id: Uuid::new_v4(), event_type, messages }
txn_id: Uuid::new_v4(),
event_type,
messages,
}
} }
#[cfg(test)] #[cfg(test)]
@ -622,18 +582,14 @@ mod test {
#[test] #[test]
fn emoji_generation() { fn emoji_generation() {
let bytes = vec![0, 0, 0, 0, 0, 0]; let bytes = vec![0, 0, 0, 0, 0, 0];
let index: Vec<(&'static str, &'static str)> = vec![0, 0, 0, 0, 0, 0, 0] let index: Vec<(&'static str, &'static str)> =
.into_iter() vec![0, 0, 0, 0, 0, 0, 0].into_iter().map(emoji_from_index).collect();
.map(emoji_from_index)
.collect();
assert_eq!(bytes_to_emoji(bytes), index.as_ref()); assert_eq!(bytes_to_emoji(bytes), index.as_ref());
let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]; let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
let index: Vec<(&'static str, &'static str)> = vec![63, 63, 63, 63, 63, 63, 63] let index: Vec<(&'static str, &'static str)> =
.into_iter() vec![63, 63, 63, 63, 63, 63, 63].into_iter().map(emoji_from_index).collect();
.map(emoji_from_index)
.collect();
assert_eq!(bytes_to_emoji(bytes), index.as_ref()); assert_eq!(bytes_to_emoji(bytes), index.as_ref());
} }

View File

@ -278,14 +278,15 @@ impl InnerSas {
_ => (self, None), _ => (self, None),
}, },
AnyToDeviceEvent::KeyVerificationMac(e) => match self { AnyToDeviceEvent::KeyVerificationMac(e) => match self {
InnerSas::KeyRecieved(s) => match s.into_mac_received(&e.sender, e.content.clone()) InnerSas::KeyRecieved(s) => {
{ match s.into_mac_received(&e.sender, e.content.clone()) {
Ok(s) => (InnerSas::MacReceived(s), None), Ok(s) => (InnerSas::MacReceived(s), None),
Err(s) => { Err(s) => {
let content = s.as_content(); let content = s.as_content();
(InnerSas::Canceled(s), Some(content.into())) (InnerSas::Canceled(s), Some(content.into()))
} }
}, }
}
InnerSas::Confirmed(s) => match s.into_done(&e.sender, e.content.clone()) { InnerSas::Confirmed(s) => match s.into_done(&e.sender, e.content.clone()) {
Ok(s) => (InnerSas::Done(s), None), Ok(s) => (InnerSas::Done(s), None),
Err(s) => { Err(s) => {

View File

@ -150,11 +150,8 @@ impl Sas {
store: Arc<Box<dyn CryptoStore>>, store: Arc<Box<dyn CryptoStore>>,
other_identity: Option<UserIdentities>, other_identity: Option<UserIdentities>,
) -> (Sas, StartContent) { ) -> (Sas, StartContent) {
let (inner, content) = InnerSas::start( let (inner, content) =
account.clone(), InnerSas::start(account.clone(), other_device.clone(), other_identity.clone());
other_device.clone(),
other_identity.clone(),
);
( (
Self::start_helper( Self::start_helper(
@ -266,11 +263,7 @@ impl Sas {
&self, &self,
settings: AcceptSettings, settings: AcceptSettings,
) -> Option<OutgoingVerificationRequest> { ) -> Option<OutgoingVerificationRequest> {
self.inner self.inner.lock().unwrap().accept().map(|c| match settings.apply(c) {
.lock()
.unwrap()
.accept()
.map(|c| match settings.apply(c) {
AcceptContent::ToDevice(c) => { AcceptContent::ToDevice(c) => {
let content = AnyToDeviceEventContent::KeyVerificationAccept(c); let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
self.content_to_request(content).into() self.content_to_request(content).into()
@ -294,10 +287,7 @@ impl Sas {
pub async fn confirm( pub async fn confirm(
&self, &self,
) -> Result< ) -> Result<
( (Option<OutgoingVerificationRequest>, Option<SignatureUploadRequest>),
Option<OutgoingVerificationRequest>,
Option<SignatureUploadRequest>,
),
CryptoStoreError, CryptoStoreError,
> { > {
let (content, done) = { let (content, done) = {
@ -310,9 +300,9 @@ impl Sas {
}; };
let mac_request = content.map(|c| match c { let mac_request = content.map(|c| match c {
event_enums::MacContent::ToDevice(c) => self event_enums::MacContent::ToDevice(c) => {
.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c)) self.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c)).into()
.into(), }
event_enums::MacContent::Room(r, c) => RoomMessageRequest { event_enums::MacContent::Room(r, c) => RoomMessageRequest {
room_id: r, room_id: r,
txn_id: Uuid::new_v4(), txn_id: Uuid::new_v4(),
@ -365,10 +355,7 @@ impl Sas {
}; };
let mut changes = Changes { let mut changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: vec![device], ..Default::default() },
changed: vec![device],
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -428,10 +415,7 @@ impl Sas {
.map(VerificationResult::SignatureUpload) .map(VerificationResult::SignatureUpload)
.unwrap_or(VerificationResult::Ok)) .unwrap_or(VerificationResult::Ok))
} else { } else {
Ok(self Ok(self.cancel().map(VerificationResult::Cancel).unwrap_or(VerificationResult::Ok))
.cancel()
.map(VerificationResult::Cancel)
.unwrap_or(VerificationResult::Ok))
} }
} }
@ -454,14 +438,8 @@ impl Sas {
.as_ref() .as_ref()
.map_or(false, |i| i.master_key() == identity.master_key()) .map_or(false, |i| i.master_key() == identity.master_key())
{ {
if self if self.verified_identities().map_or(false, |i| i.contains(&identity)) {
.verified_identities() trace!("Marking user identity of {} as verified.", identity.user_id(),);
.map_or(false, |i| i.contains(&identity))
{
trace!(
"Marking user identity of {} as verified.",
identity.user_id(),
);
if let UserIdentities::Own(i) = &identity { if let UserIdentities::Own(i) = &identity {
i.mark_as_verified(); i.mark_as_verified();
@ -500,17 +478,11 @@ impl Sas {
pub(crate) async fn mark_device_as_verified( pub(crate) async fn mark_device_as_verified(
&self, &self,
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> { ) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
let device = self let device = self.store.get_device(self.other_user_id(), self.other_device_id()).await?;
.store
.get_device(self.other_user_id(), self.other_device_id())
.await?;
if let Some(device) = device { if let Some(device) = device {
if device.keys() == self.other_device.keys() { if device.keys() == self.other_device.keys() {
if self if self.verified_devices().map_or(false, |v| v.contains(&device)) {
.verified_devices()
.map_or(false, |v| v.contains(&device))
{
trace!( trace!(
"Marking device {} {} as verified.", "Marking device {} {} as verified.",
device.user_id(), device.user_id(),
@ -571,9 +543,9 @@ impl Sas {
content: AnyMessageEventContent::KeyVerificationCancel(content), content: AnyMessageEventContent::KeyVerificationCancel(content),
} }
.into(), .into(),
CancelContent::ToDevice(c) => self CancelContent::ToDevice(c) => {
.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c)) self.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c)).into()
.into(), }
}) })
} }
@ -704,9 +676,7 @@ impl AcceptSettings {
/// ///
/// * `methods` - The methods this client allows at most /// * `methods` - The methods this client allows at most
pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self { pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self {
Self { Self { allowed_methods: methods }
allowed_methods: methods,
}
} }
fn apply(self, mut content: AcceptContent) -> AcceptContent { fn apply(self, mut content: AcceptContent) -> AcceptContent {
@ -715,15 +685,8 @@ impl AcceptSettings {
method: AcceptMethod::MSasV1(c), method: AcceptMethod::MSasV1(c),
.. ..
}) })
| AcceptContent::Room( | AcceptContent::Room(_, AcceptEventContent { method: AcceptMethod::MSasV1(c), .. }) => {
_, c.short_authentication_string.retain(|sas| self.allowed_methods.contains(sas));
AcceptEventContent {
method: AcceptMethod::MSasV1(c),
..
},
) => {
c.short_authentication_string
.retain(|sas| self.allowed_methods.contains(sas));
content content
} }
_ => content, _ => content,
@ -826,13 +789,7 @@ mod test {
); );
alice.receive_event(&event); alice.receive_event(&event);
assert!(alice assert!(alice.verified_devices().unwrap().contains(&alice.other_device()));
.verified_devices() assert!(bob.verified_devices().unwrap().contains(&bob.other_device()));
.unwrap()
.contains(&alice.other_device()));
assert!(bob
.verified_devices()
.unwrap()
.contains(&bob.other_device()));
} }
} }

View File

@ -59,10 +59,8 @@ const KEY_AGREEMENT_PROTOCOLS: &[KeyAgreementProtocol] =
&[KeyAgreementProtocol::Curve25519HkdfSha256]; &[KeyAgreementProtocol::Curve25519HkdfSha256];
const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256]; const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256];
const MACS: &[MessageAuthenticationCode] = &[MessageAuthenticationCode::HkdfHmacSha256]; const MACS: &[MessageAuthenticationCode] = &[MessageAuthenticationCode::HkdfHmacSha256];
const STRINGS: &[ShortAuthenticationString] = &[ const STRINGS: &[ShortAuthenticationString] =
ShortAuthenticationString::Decimal, &[ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji];
ShortAuthenticationString::Emoji,
];
// The max time a SAS flow can take from start to done. // The max time a SAS flow can take from start to done.
const MAX_AGE: Duration = Duration::from_secs(60 * 5); const MAX_AGE: Duration = Duration::from_secs(60 * 5);
@ -111,9 +109,7 @@ impl TryFrom<AcceptV1Content> for AcceptedProtocols {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol) if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|| !HASHES.contains(&content.hash) || !HASHES.contains(&content.hash)
|| !MACS.contains(&content.message_authentication_code) || !MACS.contains(&content.message_authentication_code)
|| (!content || (!content.short_authentication_string.contains(&ShortAuthenticationString::Emoji)
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
&& !content && !content
.short_authentication_string .short_authentication_string
.contains(&ShortAuthenticationString::Decimal)) .contains(&ShortAuthenticationString::Decimal))
@ -372,11 +368,7 @@ impl SasState<Created> {
) -> SasState<Created> { ) -> SasState<Created> {
SasState { SasState {
inner: Arc::new(Mutex::new(OlmSas::new())), inner: Arc::new(Mutex::new(OlmSas::new())),
ids: SasIds { ids: SasIds { account, other_device, other_identity },
account,
other_device,
other_identity,
},
verification_flow_id: flow_id.into(), verification_flow_id: flow_id.into(),
creation_time: Arc::new(Instant::now()), creation_time: Arc::new(Instant::now()),
@ -411,9 +403,7 @@ impl SasState<Created> {
MSasV1Content::new(self.state.protocol_definitions.clone()) MSasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."), .expect("Invalid initial protocol definitions."),
), ),
relation: Relation { relation: Relation { event_id: e.clone() },
event_id: e.clone(),
},
}, },
), ),
} }
@ -460,8 +450,8 @@ impl SasState<Created> {
} }
impl SasState<Started> { impl SasState<Started> {
/// Create a new SAS verification flow from an in-room m.key.verification.start /// Create a new SAS verification flow from an in-room
/// event. /// m.key.verification.start event.
/// ///
/// This will put us in the `started` state. /// This will put us in the `started` state.
/// ///
@ -502,11 +492,7 @@ impl SasState<Started> {
let sas = SasState { let sas = SasState {
inner: Arc::new(Mutex::new(sas)), inner: Arc::new(Mutex::new(sas)),
ids: SasIds { ids: SasIds { account, other_device, other_identity },
account,
other_device,
other_identity,
},
creation_time: Arc::new(Instant::now()), creation_time: Arc::new(Instant::now()),
last_event_time: Arc::new(Instant::now()), last_event_time: Arc::new(Instant::now()),
@ -544,11 +530,7 @@ impl SasState<Started> {
creation_time: Arc::new(Instant::now()), creation_time: Arc::new(Instant::now()),
last_event_time: Arc::new(Instant::now()), last_event_time: Arc::new(Instant::now()),
ids: SasIds { ids: SasIds { account, other_device, other_identity },
account,
other_device,
other_identity,
},
verification_flow_id: content.flow_id().into(), verification_flow_id: content.flow_id().into(),
state: Arc::new(Canceled::new(CancelCode::UnknownMethod)), state: Arc::new(Canceled::new(CancelCode::UnknownMethod)),
@ -582,19 +564,12 @@ impl SasState<Started> {
); );
match self.verification_flow_id.as_ref() { match self.verification_flow_id.as_ref() {
FlowId::ToDevice(s) => AcceptToDeviceEventContent { FlowId::ToDevice(s) => {
transaction_id: s.to_string(), AcceptToDeviceEventContent { transaction_id: s.to_string(), method }.into()
method,
} }
.into(),
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => (
r.clone(), r.clone(),
AcceptEventContent { AcceptEventContent { method, relation: Relation { event_id: e.clone() } },
method,
relation: Relation {
event_id: e.clone(),
},
},
) )
.into(), .into(),
} }
@ -662,10 +637,8 @@ impl SasState<Accepted> {
self.check_event(&sender, content.flow_id().as_str()) self.check_event(&sender, content.flow_id().as_str())
.map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(c))?;
let commitment = calculate_commitment( let commitment =
content.public_key(), calculate_commitment(content.public_key(), self.state.start_content.as_ref().clone());
self.state.start_content.as_ref().clone(),
);
if self.state.commitment != commitment { if self.state.commitment != commitment {
Err(self.cancel(CancelCode::InvalidMessage)) Err(self.cancel(CancelCode::InvalidMessage))
@ -707,9 +680,7 @@ impl SasState<Accepted> {
r.clone(), r.clone(),
KeyEventContent { KeyEventContent {
key: self.inner.lock().unwrap().public_key(), key: self.inner.lock().unwrap().public_key(),
relation: Relation { relation: Relation { event_id: e.clone() },
event_id: e.clone(),
},
}, },
) )
.into(), .into(),
@ -733,9 +704,7 @@ impl SasState<KeyReceived> {
r.clone(), r.clone(),
KeyEventContent { KeyEventContent {
key: self.inner.lock().unwrap().public_key(), key: self.inner.lock().unwrap().public_key(),
relation: Relation { relation: Relation { event_id: e.clone() },
event_id: e.clone(),
},
}, },
) )
.into(), .into(),
@ -758,8 +727,8 @@ impl SasState<KeyReceived> {
/// Get the index of the emoji of the short authentication string. /// Get the index of the emoji of the short authentication string.
/// ///
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers /// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
/// can be converted to a unique emoji defined by the spec. /// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] { pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index( get_emoji_index(
&self.inner.lock().unwrap(), &self.inner.lock().unwrap(),
@ -930,11 +899,7 @@ impl SasState<Confirmed> {
/// ///
/// The content needs to be automatically sent to the other side. /// The content needs to be automatically sent to the other side.
pub fn as_content(&self) -> MacContent { pub fn as_content(&self) -> MacContent {
get_mac_content( get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
} }
} }
@ -993,8 +958,8 @@ impl SasState<MacReceived> {
/// Get the index of the emoji of the short authentication string. /// Get the index of the emoji of the short authentication string.
/// ///
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers /// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
/// can be converted to a unique emoji defined by the spec. /// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] { pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index( get_emoji_index(
&self.inner.lock().unwrap(), &self.inner.lock().unwrap(),
@ -1026,11 +991,7 @@ impl SasState<WaitingForDone> {
/// The content needs to be automatically sent to the other side if it /// The content needs to be automatically sent to the other side if it
/// wasn't already sent. /// wasn't already sent.
pub fn as_content(&self) -> MacContent { pub fn as_content(&self) -> MacContent {
get_mac_content( get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
} }
pub fn done_content(&self) -> DoneContent { pub fn done_content(&self) -> DoneContent {
@ -1038,15 +999,9 @@ impl SasState<WaitingForDone> {
FlowId::ToDevice(_) => { FlowId::ToDevice(_) => {
unreachable!("The done content isn't supported yet for to-device verifications") unreachable!("The done content isn't supported yet for to-device verifications")
} }
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => {
r.clone(), (r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
DoneEventContent { }
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
} }
} }
@ -1088,11 +1043,7 @@ impl SasState<Done> {
/// The content needs to be automatically sent to the other side if it /// The content needs to be automatically sent to the other side if it
/// wasn't already sent. /// wasn't already sent.
pub fn as_content(&self) -> MacContent { pub fn as_content(&self) -> MacContent {
get_mac_content( get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
} }
pub fn done_content(&self) -> DoneContent { pub fn done_content(&self) -> DoneContent {
@ -1100,15 +1051,9 @@ impl SasState<Done> {
FlowId::ToDevice(_) => { FlowId::ToDevice(_) => {
unreachable!("The done content isn't supported yet for to-device verifications") unreachable!("The done content isn't supported yet for to-device verifications")
} }
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => {
r.clone(), (r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
DoneEventContent { }
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
} }
} }
@ -1144,10 +1089,7 @@ impl Canceled {
_ => unimplemented!(), _ => unimplemented!(),
}; };
Canceled { Canceled { cancel_code: code, reason }
cancel_code: code,
reason,
}
} }
} }
@ -1166,9 +1108,7 @@ impl SasState<Canceled> {
CancelEventContent { CancelEventContent {
reason: self.state.reason.to_string(), reason: self.state.reason.to_string(),
code: self.state.cancel_code.clone(), code: self.state.cancel_code.clone(),
relation: Relation { relation: Relation { event_id: e.clone() },
event_id: e.clone(),
},
}, },
) )
.into(), .into(),
@ -1331,9 +1271,7 @@ mod test {
let content = bob.as_content(); let content = bob.as_content();
let sender = UserId::try_from("@malory:example.org").unwrap(); let sender = UserId::try_from("@malory:example.org").unwrap();
alice alice.into_accepted(&sender, content).expect_err("Didn't cancel on a invalid sender");
.into_accepted(&sender, content)
.expect_err("Didn't cancel on a invalid sender");
} }
#[tokio::test] #[tokio::test]

View File

@ -152,10 +152,7 @@ impl EventBuilder {
} }
fn add_joined_event(&mut self, room_id: &RoomId, event: AnySyncRoomEvent) { fn add_joined_event(&mut self, room_id: &RoomId, event: AnySyncRoomEvent) {
self.joined_room_events self.joined_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
} }
pub fn add_custom_invited_event( pub fn add_custom_invited_event(
@ -164,10 +161,7 @@ impl EventBuilder {
event: serde_json::Value, event: serde_json::Value,
) -> &mut Self { ) -> &mut Self {
let event = serde_json::from_value::<AnySyncStateEvent>(event).unwrap(); let event = serde_json::from_value::<AnySyncStateEvent>(event).unwrap();
self.invited_room_events self.invited_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self self
} }
@ -177,10 +171,7 @@ impl EventBuilder {
event: serde_json::Value, event: serde_json::Value,
) -> &mut Self { ) -> &mut Self {
let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap(); let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap();
self.left_room_events self.left_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self self
} }
@ -350,9 +341,7 @@ impl EventBuilder {
pub fn build_sync_response(&mut self) -> SyncResponse { pub fn build_sync_response(&mut self) -> SyncResponse {
let body = self.build_json_sync_response(); let body = self.build_json_sync_response();
let response = Response::builder() let response = Response::builder().body(serde_json::to_vec(&body).unwrap()).unwrap();
.body(serde_json::to_vec(&body).unwrap())
.unwrap();
SyncResponse::try_from_http_response(response).unwrap() SyncResponse::try_from_http_response(response).unwrap()
} }
@ -393,15 +382,10 @@ pub fn sync_response(kind: SyncResponseFile) -> SyncResponse {
SyncResponseFile::Voip => &test_json::VOIP_SYNC, SyncResponseFile::Voip => &test_json::VOIP_SYNC,
}; };
let response = Response::builder() let response = Response::builder().body(data.to_string().as_bytes().to_vec()).unwrap();
.body(data.to_string().as_bytes().to_vec())
.unwrap();
SyncResponse::try_from_http_response(response).unwrap() SyncResponse::try_from_http_response(response).unwrap()
} }
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> { pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
Response::builder() Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
.status(200)
.body(json.to_string().as_bytes().to_vec())
.unwrap()
} }

View File

@ -1,5 +1,6 @@
comment_width = 100 max_width = 100
comment_width = 80
wrap_comments = true wrap_comments = true
imports_granularity = "Crate" imports_granularity = "Crate"
max_width = 100 use_small_heuristics = "Max"
group_imports = "StdExternalCrate" group_imports = "StdExternalCrate"