crypto: Answer key reshare requests only at the originally shared message index

This commit is contained in:
Damir Jelić 2021-01-28 14:07:51 +01:00
parent bf4f32eccf
commit 10da61c567
7 changed files with 139 additions and 86 deletions

View file

@ -189,8 +189,13 @@ impl Device {
pub async fn encrypt_session(
&self,
session: InboundGroupSession,
message_index: Option<u32>,
) -> OlmResult<(Session, EncryptedEventContent)> {
let export = session.export().await;
let export = if let Some(index) = message_index {
session.export_at_index(index).await
} else {
session.export().await
};
let content: ForwardedRoomKeyToDeviceEventContent = if let Ok(c) = export.try_into() {
c

View file

@ -40,7 +40,7 @@ use matrix_sdk_common::{
use crate::{
error::{OlmError, OlmResult},
olm::{InboundGroupSession, OutboundGroupSession, Session},
olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState},
requests::{OutgoingRequest, ToDeviceRequest},
store::{CryptoStoreError, Store},
Device,
@ -347,42 +347,46 @@ impl KeyRequestMachine {
.await?;
if let Some(device) = device {
if let Err(e) = self.should_share_session(
match self.should_share_session(
&device,
self.outbound_group_sessions
.get(&key_info.room_id)
.as_deref(),
) {
info!(
"Received a key request from {} {} that we won't serve: {}",
device.user_id(),
device.device_id(),
e
);
Err(e) => {
info!(
"Received a key request from {} {} that we won't serve: {}",
device.user_id(),
device.device_id(),
e
);
Ok(None)
} else {
info!(
"Serving a key request for {} from {} {}.",
key_info.session_id,
device.user_id(),
device.device_id()
);
Ok(None)
}
Ok(message_index) => {
info!(
"Serving a key request for {} from {} {} with message_index {:?}.",
key_info.session_id,
device.user_id(),
device.device_id(),
message_index,
);
match self.share_session(&session, &device).await {
Ok(s) => Ok(Some(s)),
Err(OlmError::MissingSession) => {
info!(
"Key request from {} {} is missing an Olm session, \
match self.share_session(&session, &device, message_index).await {
Ok(s) => Ok(Some(s)),
Err(OlmError::MissingSession) => {
info!(
"Key request from {} {} is missing an Olm session, \
putting the request in the wait queue",
device.user_id(),
device.device_id()
);
self.handle_key_share_without_session(device, event);
device.user_id(),
device.device_id()
);
self.handle_key_share_without_session(device, event);
Ok(None)
Ok(None)
}
Err(e) => Err(e),
}
Err(e) => Err(e),
}
}
} else {
@ -400,8 +404,11 @@ impl KeyRequestMachine {
&self,
session: &InboundGroupSession,
device: &Device,
message_index: Option<u32>,
) -> OlmResult<Session> {
let (used_session, content) = device.encrypt_session(session.clone()).await?;
let (used_session, content) = device
.encrypt_session(session.clone(), message_index)
.await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
@ -453,16 +460,18 @@ impl KeyRequestMachine {
&self,
device: &Device,
outbound_session: Option<&OutboundGroupSession>,
) -> Result<(), KeyshareDecision> {
) -> Result<Option<u32>, KeyshareDecision> {
if device.user_id() == self.user_id() {
if device.trust_state() {
Ok(())
Ok(None)
} else {
Err(KeyshareDecision::UntrustedDevice)
}
} else if let Some(outbound) = outbound_session {
if outbound.is_shared_with(device.user_id(), device.device_id()) {
Ok(())
if let ShareState::Shared(message_index) =
outbound.is_shared_with(device.user_id(), device.device_id())
{
Ok(Some(message_index))
} else {
Err(KeyshareDecision::OutboundSessionNotShared)
}
@ -830,7 +839,7 @@ mod test {
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
let export = session.export_at_index(10).await.unwrap();
let export = session.export_at_index(10).await;
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
@ -887,7 +896,7 @@ mod test {
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
let export = session.export_at_index(15).await.unwrap();
let export = session.export_at_index(15).await;
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
@ -903,7 +912,7 @@ mod test {
assert!(second_session.is_none());
let export = session.export_at_index(0).await.unwrap();
let export = session.export_at_index(0).await;
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();

View file

@ -183,9 +183,7 @@ impl InboundGroupSession {
/// If only a limited part of this session should be exported use
/// [`export_at_index()`](#method.export_at_index).
pub async fn export(&self) -> ExportedRoomKey {
self.export_at_index(self.first_known_index())
.await
.expect("Can't export at the first known index")
self.export_at_index(self.first_known_index()).await
}
/// Get the sender key that this session was received from.
@ -194,11 +192,18 @@ impl InboundGroupSession {
}
/// Export this session at the given message index.
pub async fn export_at_index(&self, message_index: u32) -> Option<ExportedRoomKey> {
let session_key =
ExportedGroupSessionKey(self.inner.lock().await.export(message_index).ok()?);
pub async fn export_at_index(&self, message_index: u32) -> ExportedRoomKey {
let message_index = std::cmp::max(self.first_known_index(), message_index);
Some(ExportedRoomKey {
let session_key = ExportedGroupSessionKey(
self.inner
.lock()
.await
.export(message_index)
.expect("Can't export session"),
);
ExportedRoomKey {
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
room_id: (&*self.room_id).clone(),
sender_key: (&*self.sender_key).to_owned(),
@ -212,7 +217,7 @@ impl InboundGroupSession {
.unwrap_or_default(),
sender_claimed_keys: (&*self.signing_key).clone(),
session_key,
})
}
}
/// Restore a Session from a previously pickled string.

View file

@ -24,7 +24,9 @@ mod inbound;
mod outbound;
pub use inbound::{InboundGroupSession, InboundGroupSessionPickle, PickledInboundGroupSession};
pub use outbound::{EncryptionSettings, OutboundGroupSession, PickledOutboundGroupSession};
pub use outbound::{
EncryptionSettings, OutboundGroupSession, PickledOutboundGroupSession, ShareState,
};
/// The private session key of a group session.
/// Can be used to create a new inbound group session.

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use dashmap::{DashMap, DashSet};
use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
events::room::{
@ -23,7 +23,7 @@ use matrix_sdk_common::{
};
use std::{
cmp::max,
collections::{BTreeMap, BTreeSet},
collections::BTreeMap,
fmt,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
@ -64,6 +64,11 @@ use super::{
const ROTATION_PERIOD: Duration = Duration::from_millis(604800000);
const ROTATION_MESSAGES: u64 = 100;
pub enum ShareState {
NotShared,
Shared(u32),
}
/// Settings for an encrypted room.
///
/// This determines the algorithm and rotation periods of a group session.
@ -120,8 +125,8 @@ pub struct OutboundGroupSession {
shared: Arc<AtomicBool>,
invalidated: Arc<AtomicBool>,
settings: Arc<EncryptionSettings>,
pub(crate) shared_with_set: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
to_share_with_set: Arc<DashMap<Uuid, Arc<ToDeviceRequest>>>,
pub(crate) shared_with_set: Arc<DashMap<UserId, DashMap<DeviceIdBox, u32>>>,
to_share_with_set: Arc<DashMap<Uuid, (Arc<ToDeviceRequest>, u32)>>,
}
impl OutboundGroupSession {
@ -165,8 +170,14 @@ impl OutboundGroupSession {
}
}
pub(crate) fn add_request(&self, request_id: Uuid, request: Arc<ToDeviceRequest>) {
self.to_share_with_set.insert(request_id, request);
pub(crate) fn add_request(
&self,
request_id: Uuid,
request: Arc<ToDeviceRequest>,
message_index: u32,
) {
self.to_share_with_set
.insert(request_id, (request, message_index));
}
/// This should be called if an the user wishes to rotate this session.
@ -180,12 +191,12 @@ impl OutboundGroupSession {
/// users/devices that received the session.
pub fn mark_request_as_sent(&self, request_id: &Uuid) {
if let Some((_, r)) = self.to_share_with_set.remove(request_id) {
let user_pairs = r.messages.iter().map(|(u, v)| {
let user_pairs = r.0.messages.iter().map(|(u, v)| {
(
u.clone(),
v.keys().filter_map(|d| {
if let DeviceIdOrAllDevices::DeviceId(d) = d {
Some(d.clone())
v.iter().filter_map(|d| {
if let DeviceIdOrAllDevices::DeviceId(d) = d.0 {
Some((d.clone(), r.1))
} else {
None
}
@ -196,7 +207,7 @@ impl OutboundGroupSession {
user_pairs.for_each(|(u, d)| {
self.shared_with_set
.entry(u)
.or_insert_with(DashSet::new)
.or_insert_with(DashMap::new)
.extend(d);
});
@ -349,28 +360,40 @@ impl OutboundGroupSession {
}
/// Has or will the session be shared with the given user/device pair.
pub(crate) fn is_shared_with(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
let shared_with = self
pub(crate) fn is_shared_with(&self, user_id: &UserId, device_id: &DeviceId) -> ShareState {
// Check if we shared the session.
let shared_state = self
.shared_with_set
.get(user_id)
.map(|d| d.contains(device_id))
.unwrap_or(false);
.and_then(|d| d.get(device_id).map(|m| ShareState::Shared(*m.value())));
let should_be_shared_with = if self.shared() {
false
if let Some(state) = shared_state {
state
} else {
// If we haven't shared the session, check if we're going to share
// the session.
let device_id = DeviceIdOrAllDevices::DeviceId(device_id.into());
self.to_share_with_set.iter().any(|item| {
if let Some(e) = item.value().messages.get(user_id) {
e.contains_key(&device_id)
} else {
false
}
})
};
// Find the first request that contains the given user id and
// device id.
let shared = self.to_share_with_set.iter().find_map(|item| {
let request = &item.value().0;
let message_index = item.value().1;
shared_with || should_be_shared_with
if request
.messages
.get(user_id)
.map(|e| e.contains_key(&device_id))
.unwrap_or(false)
{
Some(ShareState::Shared(message_index))
} else {
None
}
});
shared.unwrap_or(ShareState::NotShared)
}
}
/// Mark that the session was shared with the given user/device pair.
@ -378,8 +401,8 @@ impl OutboundGroupSession {
pub fn mark_shared_with(&self, user_id: &UserId, device_id: &DeviceId) {
self.shared_with_set
.entry(user_id.to_owned())
.or_insert_with(DashSet::new)
.insert(device_id.to_owned());
.or_insert_with(DashMap::new)
.insert(device_id.to_owned(), 0);
}
/// Get the list of requests that need to be sent out for this session to be
@ -387,7 +410,7 @@ impl OutboundGroupSession {
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
self.to_share_with_set
.iter()
.map(|i| i.value().clone())
.map(|i| i.value().0.clone())
.collect()
}
@ -465,7 +488,10 @@ impl OutboundGroupSession {
(
u.key().clone(),
#[allow(clippy::map_clone)]
u.value().iter().map(|d| d.clone()).collect(),
u.value()
.iter()
.map(|d| (d.key().clone(), *d.value()))
.collect(),
)
})
.collect(),
@ -524,9 +550,9 @@ pub struct PickledOutboundGroupSession {
/// Has the session been invalidated.
pub invalidated: bool,
/// The set of users the session has been already shared with.
pub shared_with_set: BTreeMap<UserId, BTreeSet<DeviceIdBox>>,
pub shared_with_set: BTreeMap<UserId, BTreeMap<DeviceIdBox, u32>>,
/// Requests that need to be sent out to share the session.
pub requests: BTreeMap<Uuid, Arc<ToDeviceRequest>>,
pub requests: BTreeMap<Uuid, (Arc<ToDeviceRequest>, u32)>,
}
#[cfg(test)]

View file

@ -25,11 +25,11 @@ mod utility;
pub(crate) use account::{Account, OlmDecryptionInfo, SessionType};
pub use account::{AccountPickle, OlmMessageHash, PickledAccount, ReadOnlyAccount};
pub(crate) use group_sessions::GroupSessionKey;
pub use group_sessions::{
EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle,
OutboundGroupSession, PickledInboundGroupSession, PickledOutboundGroupSession,
};
pub(crate) use group_sessions::{GroupSessionKey, ShareState};
pub use olm_rs::{account::IdentityKeys, PicklingMode};
pub use session::{PickledSession, Session, SessionPickle};
pub use signing::{PickledCrossSigningIdentity, PrivateCrossSigningIdentity};

View file

@ -29,7 +29,7 @@ use tracing::{debug, info};
use crate::{
error::{EventError, MegolmResult, OlmResult},
olm::{Account, InboundGroupSession, OutboundGroupSession, Session},
olm::{Account, InboundGroupSession, OutboundGroupSession, Session, ShareState},
store::{Changes, Store},
Device, EncryptionSettings, OlmError, ToDeviceRequest,
};
@ -263,7 +263,8 @@ impl GroupSessionManager {
{
#[allow(clippy::map_clone)]
// Devices that received this session
let shared: HashSet<DeviceIdBox> = shared.iter().map(|d| d.clone()).collect();
let shared: HashSet<DeviceIdBox> =
shared.iter().map(|d| d.key().clone()).collect();
let shared: HashSet<&DeviceId> = shared.iter().map(|d| d.as_ref()).collect();
// The difference between the devices that received the
@ -341,16 +342,23 @@ impl GroupSessionManager {
let devices: Vec<Device> = devices
.into_iter()
.map(|(_, d)| {
d.into_iter()
.filter(|d| !outbound.is_shared_with(d.user_id(), d.device_id()))
d.into_iter().filter(|d| {
matches!(
outbound.is_shared_with(d.user_id(), d.device_id()),
ShareState::NotShared
)
})
})
.flatten()
.collect();
let key_content = outbound.as_json().await;
let message_index = outbound.message_index().await;
if !devices.is_empty() {
info!(
"Sharing outbound session at index {} with {:?}",
outbound.message_index().await,
message_index,
devices.iter().fold(BTreeMap::new(), |mut acc, d| {
acc.entry(d.user_id())
.or_insert_with(BTreeSet::new)
@ -360,15 +368,13 @@ impl GroupSessionManager {
);
}
let key_content = outbound.as_json().await;
for device_map_chunk in devices.chunks(Self::MAX_TO_DEVICE_MESSAGES) {
let (id, request, used_sessions) = self
.encrypt_session_for(key_content.clone(), device_map_chunk)
.await?;
if !request.messages.is_empty() {
outbound.add_request(id, request.into());
outbound.add_request(id, request.into(), message_index);
self.outbound_sessions_being_shared
.insert(id, outbound.clone());
}