crypto: Only load the outgoing key requests when we want to send them out

This commit is contained in:
Damir Jelić 2021-04-15 19:40:24 +02:00
parent f9d290746c
commit 8c007510cd
7 changed files with 198 additions and 105 deletions

View file

@ -1839,7 +1839,16 @@ impl Client {
warn!("Error while claiming one-time keys {:?}", e);
}
for r in self.base_client.outgoing_requests().await {
// TODO we should probably abort if we get an cryptostore error here
let outgoing_requests = match self.base_client.outgoing_requests().await {
Ok(r) => r,
Err(e) => {
warn!("Could not fetch the outgoing requests {:?}", e);
vec![]
}
};
for r in outgoing_requests {
match r.request() {
OutgoingRequests::KeysQuery(request) => {
if let Err(e) = self

View file

@ -1069,12 +1069,12 @@ impl BaseClient {
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
pub async fn outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let olm = self.olm.lock().await;
match &*olm {
Some(o) => o.outgoing_requests().await,
None => vec![],
None => Ok(vec![]),
}
}

View file

@ -150,7 +150,7 @@ pub struct OutgoingKeyRequest {
}
impl OutgoingKeyRequest {
fn into_request(
fn to_request(
&self,
recipient: &UserId,
own_device_id: &DeviceId,
@ -214,30 +214,25 @@ impl KeyRequestMachine {
device_id,
store,
outbound_group_sessions,
outgoing_to_device_requests: Arc::new(DashMap::new()),
incoming_key_requests: Arc::new(DashMap::new()),
outgoing_to_device_requests: DashMap::new().into(),
incoming_key_requests: DashMap::new().into(),
wait_queue: WaitQueue::new(),
users_for_key_claim,
}
}
/// Load stored non-sent out outgoing requests
pub async fn load_outgoing_requests(&mut self) -> Result<(), CryptoStoreError> {
let infos: Vec<OutgoingKeyRequest> = vec![];
let requests: DashMap<Uuid, OutgoingRequest> = infos
.iter()
/// Load stored outgoing requests that were not yet sent out.
async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
self.store
.get_unsent_key_requests()
.await?
.into_iter()
.filter(|i| !i.sent_out)
.filter_map(|info| {
Some((
info.request_id,
info.into_request(self.user_id(), self.device_id()).ok()?,
))
.map(|info| {
info.to_request(self.user_id(), self.device_id())
.map_err(CryptoStoreError::from)
})
.collect();
self.outgoing_to_device_requests = requests.into();
Ok(())
.collect()
}
/// Our own user id.
@ -250,12 +245,18 @@ impl KeyRequestMachine {
&self.device_id
}
pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> {
#[allow(clippy::map_clone)]
self.outgoing_to_device_requests
pub async fn outgoing_to_device_requests(
&self,
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let mut key_requests = self.load_outgoing_requests().await?;
let key_forwards: Vec<OutgoingRequest> = self
.outgoing_to_device_requests
.iter()
.map(|r| (*r).clone())
.collect()
.map(|i| i.value().clone())
.collect();
key_requests.extend(key_forwards);
Ok(key_requests)
}
/// Receive a room key request event.
@ -584,10 +585,7 @@ impl KeyRequestMachine {
sent_out: false,
};
let request = info.into_request(self.user_id(), self.device_id())?;
self.save_outgoing_key_info(info).await?;
self.outgoing_to_device_requests.insert(id, request);
Ok(())
}
@ -830,7 +828,11 @@ mod test {
async fn create_machine() {
let machine = get_machine().await;
assert!(machine.outgoing_to_device_requests().is_empty());
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
}
#[async_test]
@ -843,7 +845,11 @@ mod test {
.await
.unwrap();
assert!(machine.outgoing_to_device_requests().is_empty());
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
machine
.create_outgoing_key_request(
session.room_id(),
@ -852,8 +858,15 @@ mod test {
)
.await
.unwrap();
assert!(!machine.outgoing_to_device_requests().is_empty());
assert_eq!(machine.outgoing_to_device_requests().len(), 1);
assert!(!machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert_eq!(
machine.outgoing_to_device_requests().await.unwrap().len(),
1
);
machine
.create_outgoing_key_request(
@ -863,15 +876,21 @@ mod test {
)
.await
.unwrap();
assert_eq!(machine.outgoing_to_device_requests.len(), 1);
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
let requests = machine.outgoing_to_device_requests().await.unwrap();
assert_eq!(requests.len(), 1);
let id = request.request_id;
drop(request);
let request = requests.get(0).unwrap();
machine.mark_outgoing_request_as_sent(id).await.unwrap();
assert!(machine.outgoing_to_device_requests.is_empty());
machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
}
#[async_test]
@ -892,9 +911,9 @@ mod test {
.await
.unwrap();
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
let requests = machine.outgoing_to_device_requests().await.unwrap();
let request = requests.get(0).unwrap();
let id = request.request_id;
drop(request);
machine.mark_outgoing_request_as_sent(id).await.unwrap();
@ -949,11 +968,13 @@ mod test {
.await
.unwrap();
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
let id = request.request_id;
drop(request);
let requests = machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0];
machine.mark_outgoing_request_as_sent(id).await.unwrap();
machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
let export = session.export_at_index(15).await;
@ -1160,11 +1181,8 @@ mod test {
.insert(group_session.clone());
// Get the request and convert it into a event.
let request = alice_machine
.outgoing_to_device_requests
.iter()
.next()
.unwrap();
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0];
let id = request.request_id;
let content = request
.request
@ -1178,7 +1196,6 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap();
drop(request);
alice_machine
.mark_outgoing_request_as_sent(id)
.await
@ -1199,11 +1216,8 @@ mod test {
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
// Get the request and convert it to a encrypted to-device event.
let request = bob_machine
.outgoing_to_device_requests
.iter()
.next()
.unwrap();
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0];
let id = request.request_id;
let content = request
@ -1217,7 +1231,6 @@ mod test {
.unwrap();
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
drop(request);
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
@ -1326,11 +1339,8 @@ mod test {
.insert(group_session.clone());
// Get the request and convert it into a event.
let request = alice_machine
.outgoing_to_device_requests
.iter()
.next()
.unwrap();
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0];
let id = request.request_id;
let content = request
.request
@ -1344,7 +1354,6 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap();
drop(request);
alice_machine
.mark_outgoing_request_as_sent(id)
.await
@ -1356,7 +1365,11 @@ mod test {
};
// Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests.is_empty());
assert!(bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.users_for_key_claim.is_empty());
assert!(bob_machine.wait_queue.is_empty());
@ -1364,7 +1377,11 @@ mod test {
bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob doens't have an outgoing requests since we're lacking a session.
assert!(bob_machine.outgoing_to_device_requests.is_empty());
assert!(bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(!bob_machine.users_for_key_claim.is_empty());
assert!(!bob_machine.wait_queue.is_empty());
@ -1384,15 +1401,17 @@ mod test {
assert!(bob_machine.users_for_key_claim.is_empty());
bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob now has an outgoing requests.
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
assert!(!bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.wait_queue.is_empty());
// Get the request and convert it to a encrypted to-device event.
let request = bob_machine
.outgoing_to_device_requests
.iter()
.next()
.unwrap();
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0];
let id = request.request_id;
let content = request
@ -1406,7 +1425,6 @@ mod test {
.unwrap();
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
drop(request);
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {

View file

@ -245,10 +245,9 @@ impl OlmMachine {
}
};
let mut machine = OlmMachine::new_helper(&user_id, device_id, store, account, identity);
machine.key_request_machine.load_outgoing_requests().await?;
Ok(machine)
Ok(OlmMachine::new_helper(
&user_id, device_id, store, account, identity,
))
}
/// Create a new machine with the default crypto store.
@ -295,7 +294,7 @@ impl OlmMachine {
/// machine using [`mark_request_as_sent`].
///
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
pub async fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
let mut requests = Vec::new();
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
@ -320,9 +319,14 @@ impl OlmMachine {
requests.append(&mut self.outgoing_to_device_requests());
requests.append(&mut self.verification_machine.outgoing_room_message_requests());
requests.append(&mut self.key_request_machine.outgoing_to_device_requests());
requests.append(
&mut self
.key_request_machine
.outgoing_to_device_requests()
.await?,
);
requests
Ok(requests)
}
/// Mark the request with the given request id as sent.

View file

@ -268,10 +268,11 @@ impl CryptoStore for MemoryStore {
.and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone())))
}
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
Ok(self
.outgoing_key_requests
.iter()
.filter(|i| !i.value().sent_out)
.map(|i| i.value().clone())
.collect())
}

View file

@ -450,7 +450,7 @@ pub trait CryptoStore: AsyncTraitDeps {
) -> Result<Option<OutgoingKeyRequest>>;
/// Get all outgoing key requests that we have in the store.
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
/// Delete an outoing key request that we created that matches the given
/// request id.

View file

@ -149,6 +149,7 @@ pub struct SledStore {
outbound_group_sessions: Tree,
outgoing_key_requests: Tree,
unsent_key_requests: Tree,
key_requests_by_info: Tree,
devices: Tree,
@ -215,6 +216,7 @@ impl SledStore {
let identities = db.open_tree("identities")?;
let outgoing_key_requests = db.open_tree("outgoing_key_requests")?;
let unsent_key_requests = db.open_tree("unsent_key_requests")?;
let key_requests_by_info = db.open_tree("key_requests_by_info")?;
let session_cache = SessionStore::new();
@ -240,6 +242,7 @@ impl SledStore {
inbound_group_sessions,
outbound_group_sessions,
outgoing_key_requests,
unsent_key_requests,
key_requests_by_info,
devices,
tracked_users,
@ -376,6 +379,7 @@ impl SledStore {
&self.outbound_group_sessions,
&self.olm_hashes,
&self.outgoing_key_requests,
&self.unsent_key_requests,
&self.key_requests_by_info,
)
.transaction(
@ -389,6 +393,7 @@ impl SledStore {
outbound_sessions,
hashes,
outgoing_key_requests,
unsent_key_requests,
key_requests_by_info,
)| {
if let Some(a) = &account_pickle {
@ -463,11 +468,23 @@ impl SledStore {
key_request.request_id.encode(),
)?;
outgoing_key_requests.insert(
key_request.request_id.encode(),
serde_json::to_vec(&key_request)
.map_err(ConflictableTransactionError::Abort)?,
)?;
let key_request_id = key_request.request_id.encode();
if key_request.sent_out {
unsent_key_requests.remove(key_request_id.clone())?;
outgoing_key_requests.insert(
key_request_id,
serde_json::to_vec(&key_request)
.map_err(ConflictableTransactionError::Abort)?,
)?;
} else {
outgoing_key_requests.remove(key_request_id.clone())?;
unsent_key_requests.insert(
key_request_id,
serde_json::to_vec(&key_request)
.map_err(ConflictableTransactionError::Abort)?,
)?;
}
}
Ok(())
@ -479,6 +496,28 @@ impl SledStore {
Ok(())
}
async fn get_outgoing_key_request_helper(
&self,
id: &[u8],
) -> Result<Option<OutgoingKeyRequest>> {
let request = self
.outgoing_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?;
let request = if request.is_none() {
self.unsent_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?
} else {
request
};
Ok(request)
}
}
#[async_trait]
@ -685,11 +724,9 @@ impl CryptoStore for SledStore {
&self,
request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>> {
Ok(self
.outgoing_key_requests
.get(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()?)
let request_id = request_id.encode();
self.get_outgoing_key_request_helper(&request_id).await
}
async fn get_key_request_by_info(
@ -699,19 +736,15 @@ impl CryptoStore for SledStore {
let id = self.key_requests_by_info.get(key_info.encode())?;
if let Some(id) = id {
Ok(self
.outgoing_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?)
self.get_outgoing_key_request_helper(&id).await
} else {
Ok(None)
}
}
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
let requests: Result<Vec<OutgoingKeyRequest>> = self
.outgoing_key_requests
.unsent_key_requests
.iter()
.map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from))
.collect();
@ -720,16 +753,30 @@ impl CryptoStore for SledStore {
}
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
let ret: Result<(), TransactionError<serde_json::Error>> =
(&self.outgoing_key_requests, &self.key_requests_by_info).transaction(
|(outgoing_key_requests, key_requests_by_info)| {
let request: Option<OutgoingKeyRequest> = outgoing_key_requests
let ret: Result<(), TransactionError<serde_json::Error>> = (
&self.outgoing_key_requests,
&self.unsent_key_requests,
&self.key_requests_by_info,
)
.transaction(
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
if let Some(request) = request {
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
if let Some(request) = sent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
if let Some(request) = unsent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
@ -1328,7 +1375,21 @@ mod test {
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
assert_eq!(request, stored_request);
assert!(!store.get_outgoing_key_requests().await.unwrap().is_empty());
assert!(!store.get_unsent_key_requests().await.unwrap().is_empty());
let request = OutgoingKeyRequest {
request_id: id,
info: info.clone(),
sent_out: true,
};
let mut changes = Changes::default();
changes.key_requests.push(request.clone());
store.save_changes(changes).await.unwrap();
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
assert_eq!(Some(request), stored_request);
store.delete_outgoing_key_request(id).await.unwrap();
@ -1337,6 +1398,6 @@ mod test {
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
assert_eq!(None, stored_request);
assert!(store.get_outgoing_key_requests().await.unwrap().is_empty());
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
}
}