crypto: Don't mark outbound group sessions automatically as shared.

master
Damir Jelić 2020-10-01 16:31:24 +02:00
parent fc6ff2c78a
commit 02c765f903
8 changed files with 122 additions and 52 deletions

View File

@ -1303,10 +1303,13 @@ impl Client {
} }
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
async fn send_to_device(&self, request: ToDeviceRequest) -> Result<ToDeviceResponse> { async fn send_to_device(&self, request: &ToDeviceRequest) -> Result<ToDeviceResponse> {
let txn_id_string = request.txn_id_string(); let txn_id_string = request.txn_id_string();
let request = let request = RumaToDeviceRequest::new(
RumaToDeviceRequest::new(request.event_type, &txn_id_string, request.messages); request.event_type.clone(),
&txn_id_string,
request.messages.clone(),
);
self.send(request).await self.send(request).await
} }
@ -1468,7 +1471,8 @@ impl Client {
} }
} }
OutgoingRequests::ToDeviceRequest(request) => { OutgoingRequests::ToDeviceRequest(request) => {
if let Ok(resp) = self.send_to_device(request.clone()).await { // TODO remove this unwrap
if let Ok(resp) = self.send_to_device(&request).await {
self.base_client self.base_client
.mark_request_as_sent(&r.request_id(), &resp) .mark_request_as_sent(&r.request_id(), &resp)
.await .await
@ -1551,7 +1555,11 @@ impl Client {
.expect("Keys don't need to be uploaded"); .expect("Keys don't need to be uploaded");
for request in requests.drain(..) { for request in requests.drain(..) {
self.send_to_device(request).await?; let response = self.send_to_device(&request).await?;
self.base_client
.mark_request_as_sent(&request.txn_id, &response)
.await?;
} }
Ok(()) Ok(())

View File

@ -1304,7 +1304,7 @@ impl BaseClient {
/// Get a to-device request that will share a group session for a room. /// Get a to-device request that will share a group session for a room.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn share_group_session(&self, room_id: &RoomId) -> Result<Vec<ToDeviceRequest>> { pub async fn share_group_session(&self, room_id: &RoomId) -> Result<Vec<Arc<ToDeviceRequest>>> {
let room = self.get_joined_room(room_id).await.expect("No room found"); let room = self.get_joined_room(room_id).await.expect("No room found");
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
@ -1881,18 +1881,19 @@ impl BaseClient {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use serde_json::json;
use std::convert::TryFrom;
#[cfg(feature = "messages")] #[cfg(feature = "messages")]
use crate::{ use crate::{
events::AnySyncRoomEvent, identifiers::event_id, BaseClientConfig, JsonStore, Raw, events::AnySyncRoomEvent, identifiers::event_id, BaseClientConfig, JsonStore, Raw,
}; };
use crate::{ use crate::{BaseClient, Session};
identifiers::{room_id, user_id, RoomId},
BaseClient, Session, use matrix_sdk_common::identifiers::{room_id, user_id, RoomId};
};
use matrix_sdk_common_macros::async_trait; use matrix_sdk_common_macros::async_trait;
use matrix_sdk_test::{async_test, test_json, EventBuilder, EventsJson}; use matrix_sdk_test::{async_test, test_json, EventBuilder, EventsJson};
use serde_json::json;
use std::convert::TryFrom;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use tempfile::tempdir; use tempfile::tempdir;

View File

@ -21,6 +21,7 @@ use matrix_sdk_common::{
identifiers::{RoomId, UserId}, identifiers::{RoomId, UserId},
uuid::Uuid, uuid::Uuid,
}; };
use tracing::debug;
use crate::{ use crate::{
error::{EventError, MegolmResult, OlmResult}, error::{EventError, MegolmResult, OlmResult},
@ -57,6 +58,12 @@ impl GroupSessionManager {
self.outbound_group_sessions.remove(room_id).is_some() self.outbound_group_sessions.remove(room_id).is_some()
} }
pub fn mark_request_as_sent(&self, request_id: &Uuid) {
self.outbound_sessions_being_shared
.remove(request_id)
.map(|(_, s)| s.mark_request_as_sent(request_id));
}
/// Get an outbound group session for a room, if one exists. /// Get an outbound group session for a room, if one exists.
/// ///
/// # Arguments /// # Arguments
@ -111,11 +118,10 @@ impl GroupSessionManager {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
users_to_share_with: impl Iterator<Item = &UserId>,
) -> OlmResult<()> { ) -> OlmResult<()> {
let (outbound, inbound) = self let (outbound, inbound) = self
.account .account
.create_group_session_pair(room_id, settings, users_to_share_with) .create_group_session_pair(room_id, settings)
.await .await
.map_err(|_| EventError::UnsupportedAlgorithm)?; .map_err(|_| EventError::UnsupportedAlgorithm)?;
@ -140,8 +146,8 @@ impl GroupSessionManager {
room_id: &RoomId, room_id: &RoomId,
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>, encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<ToDeviceRequest>> { ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.create_outbound_group_session(room_id, encryption_settings.into(), users) self.create_outbound_group_session(room_id, encryption_settings.into())
.await?; .await?;
let session = self.outbound_group_sessions.get(room_id).unwrap(); let session = self.outbound_group_sessions.get(room_id).unwrap();
@ -149,15 +155,9 @@ impl GroupSessionManager {
panic!("Session is already shared"); panic!("Session is already shared");
} }
// TODO don't mark the session as shared automatically, only when all
// the requests are done, failure to send these requests will likely end
// up in wedged sessions. We'll need to store the requests and let the
// caller mark them as sent using an UUID.
session.mark_as_shared();
let mut devices: Vec<Device> = Vec::new(); let mut devices: Vec<Device> = Vec::new();
for user_id in session.users_to_share_with() { for user_id in users {
let user_devices = self.store.get_user_devices(&user_id).await?; let user_devices = self.store.get_user_devices(&user_id).await?;
devices.extend(user_devices.devices().filter(|d| !d.is_blacklisted())); devices.extend(user_devices.devices().filter(|d| !d.is_blacklisted()));
} }
@ -193,11 +193,25 @@ impl GroupSessionManager {
let id = Uuid::new_v4(); let id = Uuid::new_v4();
requests.push(ToDeviceRequest { let request = Arc::new(ToDeviceRequest {
event_type: EventType::RoomEncrypted, event_type: EventType::RoomEncrypted,
txn_id: id, txn_id: id,
messages, messages,
}); });
session.add_request(id, request.clone());
self.outbound_sessions_being_shared
.insert(id, session.clone());
requests.push(request);
}
if requests.is_empty() {
debug!(
"Session {} for room {} doesn't need to be shared with anyone, marking as shared",
session.session_id(),
session.room_id()
);
session.mark_as_shared();
} }
Ok(requests) Ok(requests)

View File

@ -323,10 +323,7 @@ impl KeyRequestMachine {
Err(KeyshareDecision::UntrustedDevice) Err(KeyshareDecision::UntrustedDevice)
} }
} else if let Some(outbound) = outbound_session { } else if let Some(outbound) = outbound_session {
if outbound if outbound.is_shared_with(device.user_id(), device.device_id()) {
.shared_with()
.contains(&(device.user_id().to_owned(), device.device_id().to_owned()))
{
Ok(()) Ok(())
} else { } else {
Err(KeyshareDecision::OutboundSessionNotShared) Err(KeyshareDecision::OutboundSessionNotShared)

View File

@ -574,7 +574,7 @@ impl OlmMachine {
room_id: &RoomId, room_id: &RoomId,
) -> OlmResult<()> { ) -> OlmResult<()> {
self.group_session_manager self.group_session_manager
.create_outbound_group_session(room_id, EncryptionSettings::default(), [].iter()) .create_outbound_group_session(room_id, EncryptionSettings::default())
.await .await
} }
@ -645,7 +645,7 @@ impl OlmMachine {
room_id: &RoomId, room_id: &RoomId,
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>, encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<ToDeviceRequest>> { ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.group_session_manager self.group_session_manager
.share_group_session(room_id, users, encryption_settings) .share_group_session(room_id, users, encryption_settings)
.await .await
@ -705,6 +705,7 @@ impl OlmMachine {
self.key_request_machine self.key_request_machine
.mark_outgoing_request_as_sent(request_id) .mark_outgoing_request_as_sent(request_id)
.await?; .await?;
self.group_session_manager.mark_request_as_sent(request_id);
Ok(()) Ok(())
} }
@ -1037,6 +1038,7 @@ pub(crate) mod test {
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
sync::Arc,
time::SystemTime, time::SystemTime,
}; };
@ -1103,7 +1105,7 @@ pub(crate) mod test {
get_keys::Response::try_from(data).expect("Can't parse the keys upload response") get_keys::Response::try_from(data).expect("Can't parse the keys upload response")
} }
fn to_device_requests_to_content(requests: Vec<ToDeviceRequest>) -> EncryptedEventContent { fn to_device_requests_to_content(requests: Vec<Arc<ToDeviceRequest>>) -> EncryptedEventContent {
let to_device_request = &requests[0]; let to_device_request = &requests[0];
let content: Raw<EncryptedEventContent> = serde_json::from_str( let content: Raw<EncryptedEventContent> = serde_json::from_str(

View File

@ -878,7 +878,6 @@ impl ReadOnlyAccount {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
users_to_share_with: impl Iterator<Item = &UserId>,
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 { if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 {
return Err(()); return Err(());
@ -889,7 +888,6 @@ impl ReadOnlyAccount {
self.identity_keys.clone(), self.identity_keys.clone(),
room_id, room_id,
settings, settings,
users_to_share_with,
); );
let identity_keys = self.identity_keys(); let identity_keys = self.identity_keys();
@ -912,7 +910,7 @@ impl ReadOnlyAccount {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
self.create_group_session_pair(room_id, EncryptionSettings::default(), [].iter()) self.create_group_session_pair(room_id, EncryptionSettings::default())
.await .await
} }

View File

@ -147,7 +147,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into()); let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into());
let (session, _) = account let (session, _) = account
.create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter()) .create_group_session_pair(&room_id!("!test_room:example.org"), settings)
.await .await
.unwrap(); .unwrap();
@ -165,7 +165,7 @@ mod test {
}; };
let (mut session, _) = account let (mut session, _) = account
.create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter()) .create_group_session_pair(&room_id!("!test_room:example.org"), settings)
.await .await
.unwrap(); .unwrap();

View File

@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dashmap::{setref::multiple::RefMulti, DashSet}; use dashmap::{DashMap, DashSet};
use matrix_sdk_common::{api::r0::to_device::DeviceIdOrAllDevices, uuid::Uuid};
use std::{ use std::{
cmp::min, cmp::min,
fmt, fmt,
@ -22,6 +23,7 @@ use std::{
}, },
time::Duration, time::Duration,
}; };
use tracing::debug;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::{
@ -32,15 +34,17 @@ use matrix_sdk_common::{
instant::Instant, instant::Instant,
locks::Mutex, locks::Mutex,
}; };
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use serde_json::{json, Value}; use serde_json::{json, Value};
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
pub use olm_rs::{ pub use olm_rs::{
account::IdentityKeys, account::IdentityKeys,
session::{OlmMessage, PreKeyMessage}, session::{OlmMessage, PreKeyMessage},
utility::OlmUtility, utility::OlmUtility,
}; };
use crate::ToDeviceRequest;
use super::GroupSessionKey; use super::GroupSessionKey;
const ROTATION_PERIOD: Duration = Duration::from_millis(604800000); const ROTATION_PERIOD: Duration = Duration::from_millis(604800000);
@ -101,8 +105,8 @@ pub struct OutboundGroupSession {
message_count: Arc<AtomicU64>, message_count: Arc<AtomicU64>,
shared: Arc<AtomicBool>, shared: Arc<AtomicBool>,
settings: Arc<EncryptionSettings>, settings: Arc<EncryptionSettings>,
shared_with_set: Arc<DashSet<(UserId, DeviceIdBox)>>, shared_with_set: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
to_share_with_set: Arc<DashSet<UserId>>, to_share_with_set: Arc<DashMap<Uuid, Arc<ToDeviceRequest>>>,
} }
impl OutboundGroupSession { impl OutboundGroupSession {
@ -121,18 +125,15 @@ impl OutboundGroupSession {
/// ///
/// * `settings` - Settings determining the algorithm and rotation period of /// * `settings` - Settings determining the algorithm and rotation period of
/// the outbound group session. /// the outbound group session.
pub fn new<'a>( pub fn new(
device_id: Arc<DeviceIdBox>, device_id: Arc<DeviceIdBox>,
identity_keys: Arc<IdentityKeys>, identity_keys: Arc<IdentityKeys>,
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
users_to_share_with: impl Iterator<Item = &'a UserId>,
) -> Self { ) -> Self {
let session = OlmOutboundGroupSession::new(); let session = OlmOutboundGroupSession::new();
let session_id = session.session_id(); let session_id = session.session_id();
let users_to_share_with = users_to_share_with.cloned().collect();
OutboundGroupSession { OutboundGroupSession {
inner: Arc::new(Mutex::new(session)), inner: Arc::new(Mutex::new(session)),
room_id: Arc::new(room_id.to_owned()), room_id: Arc::new(room_id.to_owned()),
@ -143,13 +144,52 @@ impl OutboundGroupSession {
message_count: Arc::new(AtomicU64::new(0)), message_count: Arc::new(AtomicU64::new(0)),
shared: Arc::new(AtomicBool::new(false)), shared: Arc::new(AtomicBool::new(false)),
settings: Arc::new(settings), settings: Arc::new(settings),
shared_with_set: Arc::new(DashSet::new()), shared_with_set: Arc::new(DashMap::new()),
to_share_with_set: Arc::new(users_to_share_with), to_share_with_set: Arc::new(DashMap::new()),
} }
} }
pub(crate) fn users_to_share_with(&self) -> impl Iterator<Item = RefMulti<'_, UserId>> + '_ { pub fn add_request(&self, request_id: Uuid, request: Arc<ToDeviceRequest>) {
self.to_share_with_set.iter() self.to_share_with_set.insert(request_id, request);
}
/// Mark the request with the given request id as sent.
///
/// This removes the request from the queue and marks the set of
/// users/devices that received the session.
pub fn mark_request_as_sent(&self, request_id: &Uuid) {
let request = self.to_share_with_set.remove(request_id);
request.map(|(_, r)| {
let user_pairs = r.messages.iter().map(|(u, v)| {
(
u.clone(),
v.keys().filter_map(|d| {
if let DeviceIdOrAllDevices::DeviceId(d) = d {
Some(d.clone())
} else {
None
}
}),
)
});
user_pairs.for_each(|(u, d)| {
self.shared_with_set
.entry(u)
.or_insert_with(DashSet::new)
.extend(d);
})
});
if self.to_share_with_set.is_empty() {
debug!(
"Marking session {} for room {} as shared.",
self.session_id(),
self.room_id
);
self.mark_as_shared();
}
} }
/// Encrypt the given plaintext using this session. /// Encrypt the given plaintext using this session.
@ -246,6 +286,11 @@ impl OutboundGroupSession {
GroupSessionKey(session.session_key()) GroupSessionKey(session.session_key())
} }
/// Get the room id of the room this session belongs to.
pub fn room_id(&self) -> &RoomId {
&self.room_id
}
/// Returns the unique identifier for this session. /// Returns the unique identifier for this session.
pub fn session_id(&self) -> &str { pub fn session_id(&self) -> &str {
&self.session_id &self.session_id
@ -273,15 +318,20 @@ impl OutboundGroupSession {
} }
/// The set of users this session is shared with. /// The set of users this session is shared with.
pub(crate) fn shared_with(&self) -> &DashSet<(UserId, DeviceIdBox)> { pub(crate) fn is_shared_with(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
&self.shared_with_set self.shared_with_set
.get(user_id)
.map(|d| d.contains(device_id))
.unwrap_or(false)
} }
/// Mark that the session was shared with the given user/device pair. /// Mark that the session was shared with the given user/device pair.
#[allow(dead_code)] #[cfg(test)]
pub fn mark_shared_with(&self, user_id: &UserId, device_id: &DeviceId) { pub fn mark_shared_with(&self, user_id: &UserId, device_id: &DeviceId) {
self.shared_with_set self.shared_with_set
.insert((user_id.to_owned(), device_id.to_owned())); .entry(user_id.to_owned())
.or_insert_with(DashSet::new)
.insert(device_id.to_owned());
} }
} }