crypto: Improve key imports.
This patch changes so key imports load all existing sessions at once instead loading a single session for each session we are importing. It removes the need to lock the session when we check the first known index and exposes the total number of sessions the key export contained.
This commit is contained in:
parent
e20b1efae9
commit
804bd221b2
6 changed files with 61 additions and 31 deletions
|
@ -2042,6 +2042,10 @@ impl Client {
|
|||
/// * `passphrase` - The passphrase that should be used to decrypt the
|
||||
/// exported room keys.
|
||||
///
|
||||
/// Returns a tuple of numbers that represent the number of sessions that
|
||||
/// were imported and the total number of sessions that were found in the
|
||||
/// key export.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method will panic if it isn't run on a Tokio runtime.
|
||||
|
@ -2071,7 +2075,7 @@ impl Client {
|
|||
feature = "docs",
|
||||
doc(cfg(all(encryption, not(target_arch = "wasm32"))))
|
||||
)]
|
||||
pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result<usize> {
|
||||
pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result<(usize, usize)> {
|
||||
let olm = self
|
||||
.base_client
|
||||
.olm_machine()
|
||||
|
|
|
@ -314,7 +314,7 @@ mod test {
|
|||
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
|
||||
|
||||
assert_eq!(export, decrypted);
|
||||
assert_eq!(machine.import_keys(decrypted).await.unwrap(), 0);
|
||||
assert_eq!(machine.import_keys(decrypted).await.unwrap(), (0, 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -637,8 +637,8 @@ impl KeyRequestMachine {
|
|||
// If we have a previous session, check if we have a better version
|
||||
// and store the new one if so.
|
||||
let session = if let Some(old_session) = old_session {
|
||||
let first_old_index = old_session.first_known_index().await;
|
||||
let first_index = session.first_known_index().await;
|
||||
let first_old_index = old_session.first_known_index();
|
||||
let first_index = session.first_known_index();
|
||||
|
||||
if first_old_index > first_index {
|
||||
self.mark_as_done(info).await?;
|
||||
|
@ -855,7 +855,7 @@ mod test {
|
|||
.unwrap();
|
||||
let first_session = first_session.unwrap();
|
||||
|
||||
assert_eq!(first_session.first_known_index().await, 10);
|
||||
assert_eq!(first_session.first_known_index(), 10);
|
||||
|
||||
machine
|
||||
.store
|
||||
|
@ -914,7 +914,7 @@ mod test {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(second_session.unwrap().first_known_index().await, 0);
|
||||
assert_eq!(second_session.unwrap().first_known_index(), 0);
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
|
|
@ -1012,7 +1012,9 @@ impl OlmMachine {
|
|||
/// imported into our store. If we already have a better version of a key
|
||||
/// the key will *not* be imported.
|
||||
///
|
||||
/// Returns the number of sessions that were imported to the store.
|
||||
/// Returns a tuple of numbers that represent the number of sessions that
|
||||
/// were imported and the total number of sessions that were found in the
|
||||
/// key export.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```no_run
|
||||
|
@ -1028,32 +1030,47 @@ impl OlmMachine {
|
|||
/// machine.import_keys(exported_keys).await.unwrap();
|
||||
/// # });
|
||||
/// ```
|
||||
pub async fn import_keys(&self, mut exported_keys: Vec<ExportedRoomKey>) -> StoreResult<usize> {
|
||||
pub async fn import_keys(
|
||||
&self,
|
||||
exported_keys: Vec<ExportedRoomKey>,
|
||||
) -> StoreResult<(usize, usize)> {
|
||||
struct ShallowSessions {
|
||||
inner: BTreeMap<Arc<RoomId>, u32>,
|
||||
}
|
||||
|
||||
impl ShallowSessions {
|
||||
fn has_better_session(&self, session: &InboundGroupSession) -> bool {
|
||||
self.inner
|
||||
.get(&session.room_id)
|
||||
.map(|existing| existing <= &session.first_known_index())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
let mut sessions = Vec::new();
|
||||
|
||||
for key in exported_keys.drain(..) {
|
||||
let existing_sessions = ShallowSessions {
|
||||
inner: self
|
||||
.store
|
||||
.get_inbound_group_sessions()
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|s| {
|
||||
let index = s.first_known_index();
|
||||
(s.room_id, index)
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let total_sessions = exported_keys.len();
|
||||
|
||||
for key in exported_keys.into_iter() {
|
||||
let session = InboundGroupSession::from_export(key)?;
|
||||
|
||||
// Only import the session if we didn't have this session or if it's
|
||||
// a better version of the same session, that is the first known
|
||||
// index is lower.
|
||||
// TODO load all sessions so we don't do a thousand small loads.
|
||||
if let Some(existing_session) = self
|
||||
.store
|
||||
.get_inbound_group_session(
|
||||
&session.room_id,
|
||||
&session.sender_key,
|
||||
session.session_id(),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
let first_index = session.first_known_index().await;
|
||||
let existing_index = existing_session.first_known_index().await;
|
||||
|
||||
if first_index < existing_index {
|
||||
sessions.push(session)
|
||||
}
|
||||
} else {
|
||||
if !existing_sessions.has_better_session(&session) {
|
||||
sessions.push(session)
|
||||
}
|
||||
}
|
||||
|
@ -1072,7 +1089,7 @@ impl OlmMachine {
|
|||
num_sessions
|
||||
);
|
||||
|
||||
Ok(num_sessions)
|
||||
Ok((num_sessions, total_sessions))
|
||||
}
|
||||
|
||||
/// Export the keys that match the given predicate.
|
||||
|
|
|
@ -57,6 +57,7 @@ use crate::error::{EventError, MegolmResult};
|
|||
pub struct InboundGroupSession {
|
||||
inner: Arc<Mutex<OlmInboundGroupSession>>,
|
||||
session_id: Arc<str>,
|
||||
first_known_index: u32,
|
||||
pub(crate) sender_key: Arc<str>,
|
||||
pub(crate) signing_key: Arc<BTreeMap<DeviceKeyAlgorithm, String>>,
|
||||
pub(crate) room_id: Arc<RoomId>,
|
||||
|
@ -89,6 +90,7 @@ impl InboundGroupSession {
|
|||
) -> Result<Self, OlmGroupSessionError> {
|
||||
let session = OlmInboundGroupSession::new(&session_key.0)?;
|
||||
let session_id = session.session_id();
|
||||
let first_known_index = session.first_known_index();
|
||||
|
||||
let mut keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new();
|
||||
keys.insert(DeviceKeyAlgorithm::Ed25519, signing_key.to_owned());
|
||||
|
@ -97,6 +99,7 @@ impl InboundGroupSession {
|
|||
inner: Arc::new(Mutex::new(session)),
|
||||
session_id: session_id.into(),
|
||||
sender_key: sender_key.to_owned().into(),
|
||||
first_known_index,
|
||||
signing_key: Arc::new(keys),
|
||||
room_id: Arc::new(room_id.clone()),
|
||||
forwarding_chains: Arc::new(Mutex::new(None)),
|
||||
|
@ -134,6 +137,7 @@ impl InboundGroupSession {
|
|||
let key = Zeroizing::from(mem::take(&mut content.session_key));
|
||||
|
||||
let session = OlmInboundGroupSession::import(&key)?;
|
||||
let first_known_index = session.first_known_index();
|
||||
let mut forwarding_chains = content.forwarding_curve25519_key_chain.clone();
|
||||
forwarding_chains.push(sender_key.to_owned());
|
||||
|
||||
|
@ -147,6 +151,7 @@ impl InboundGroupSession {
|
|||
inner: Arc::new(Mutex::new(session)),
|
||||
session_id: content.session_id.as_str().into(),
|
||||
sender_key: content.sender_key.as_str().into(),
|
||||
first_known_index,
|
||||
signing_key: Arc::new(sender_claimed_key),
|
||||
room_id: Arc::new(content.room_id.clone()),
|
||||
forwarding_chains: Arc::new(Mutex::new(Some(forwarding_chains))),
|
||||
|
@ -178,7 +183,7 @@ impl InboundGroupSession {
|
|||
/// If only a limited part of this session should be exported use
|
||||
/// [`export_at_index()`](#method.export_at_index).
|
||||
pub async fn export(&self) -> ExportedRoomKey {
|
||||
self.export_at_index(self.first_known_index().await)
|
||||
self.export_at_index(self.first_known_index())
|
||||
.await
|
||||
.expect("Can't export at the first known index")
|
||||
}
|
||||
|
@ -221,12 +226,14 @@ impl InboundGroupSession {
|
|||
pickle_mode: PicklingMode,
|
||||
) -> Result<Self, OlmGroupSessionError> {
|
||||
let session = OlmInboundGroupSession::unpickle(pickle.pickle.0, pickle_mode)?;
|
||||
let first_known_index = session.first_known_index();
|
||||
let session_id = session.session_id();
|
||||
|
||||
Ok(InboundGroupSession {
|
||||
inner: Arc::new(Mutex::new(session)),
|
||||
session_id: session_id.into(),
|
||||
sender_key: pickle.sender_key.into(),
|
||||
first_known_index,
|
||||
signing_key: Arc::new(pickle.signing_key),
|
||||
room_id: Arc::new(pickle.room_id),
|
||||
forwarding_chains: Arc::new(Mutex::new(pickle.forwarding_chains)),
|
||||
|
@ -245,8 +252,8 @@ impl InboundGroupSession {
|
|||
}
|
||||
|
||||
/// Get the first message index we know how to decrypt.
|
||||
pub async fn first_known_index(&self) -> u32 {
|
||||
self.inner.lock().await.first_known_index()
|
||||
pub fn first_known_index(&self) -> u32 {
|
||||
self.first_known_index
|
||||
}
|
||||
|
||||
/// Decrypt the given ciphertext.
|
||||
|
@ -368,6 +375,7 @@ impl TryFrom<ExportedRoomKey> for InboundGroupSession {
|
|||
|
||||
fn try_from(key: ExportedRoomKey) -> Result<Self, Self::Error> {
|
||||
let session = OlmInboundGroupSession::import(&key.session_key.0)?;
|
||||
let first_known_index = session.first_known_index();
|
||||
|
||||
let forwarding_chains = if key.forwarding_curve25519_key_chain.is_empty() {
|
||||
None
|
||||
|
@ -379,6 +387,7 @@ impl TryFrom<ExportedRoomKey> for InboundGroupSession {
|
|||
inner: Arc::new(Mutex::new(session)),
|
||||
session_id: key.session_id.into(),
|
||||
sender_key: key.sender_key.into(),
|
||||
first_known_index,
|
||||
signing_key: Arc::new(key.sender_claimed_keys),
|
||||
room_id: Arc::new(key.room_id),
|
||||
forwarding_chains: Arc::new(Mutex::new(forwarding_chains)),
|
||||
|
|
|
@ -213,7 +213,7 @@ pub(crate) mod test {
|
|||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(0, inbound.first_known_index().await);
|
||||
assert_eq!(0, inbound.first_known_index());
|
||||
|
||||
assert_eq!(outbound.session_id(), inbound.session_id());
|
||||
|
||||
|
|
Loading…
Reference in a new issue