crypto: Only load the outgoing key requests when we want to send them out

master
Damir Jelić 2021-04-15 19:40:24 +02:00
parent f9d290746c
commit 8c007510cd
7 changed files with 198 additions and 105 deletions

View File

@ -1839,7 +1839,16 @@ impl Client {
warn!("Error while claiming one-time keys {:?}", e); warn!("Error while claiming one-time keys {:?}", e);
} }
for r in self.base_client.outgoing_requests().await { // TODO we should probably abort if we get an cryptostore error here
let outgoing_requests = match self.base_client.outgoing_requests().await {
Ok(r) => r,
Err(e) => {
warn!("Could not fetch the outgoing requests {:?}", e);
vec![]
}
};
for r in outgoing_requests {
match r.request() { match r.request() {
OutgoingRequests::KeysQuery(request) => { OutgoingRequests::KeysQuery(request) => {
if let Err(e) = self if let Err(e) = self

View File

@ -1069,12 +1069,12 @@ impl BaseClient {
/// [`mark_request_as_sent`]: #method.mark_request_as_sent /// [`mark_request_as_sent`]: #method.mark_request_as_sent
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn outgoing_requests(&self) -> Vec<OutgoingRequest> { pub async fn outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
match &*olm { match &*olm {
Some(o) => o.outgoing_requests().await, Some(o) => o.outgoing_requests().await,
None => vec![], None => Ok(vec![]),
} }
} }

View File

@ -150,7 +150,7 @@ pub struct OutgoingKeyRequest {
} }
impl OutgoingKeyRequest { impl OutgoingKeyRequest {
fn into_request( fn to_request(
&self, &self,
recipient: &UserId, recipient: &UserId,
own_device_id: &DeviceId, own_device_id: &DeviceId,
@ -214,30 +214,25 @@ impl KeyRequestMachine {
device_id, device_id,
store, store,
outbound_group_sessions, outbound_group_sessions,
outgoing_to_device_requests: Arc::new(DashMap::new()), outgoing_to_device_requests: DashMap::new().into(),
incoming_key_requests: Arc::new(DashMap::new()), incoming_key_requests: DashMap::new().into(),
wait_queue: WaitQueue::new(), wait_queue: WaitQueue::new(),
users_for_key_claim, users_for_key_claim,
} }
} }
/// Load stored non-sent out outgoing requests /// Load stored outgoing requests that were not yet sent out.
pub async fn load_outgoing_requests(&mut self) -> Result<(), CryptoStoreError> { async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let infos: Vec<OutgoingKeyRequest> = vec![]; self.store
let requests: DashMap<Uuid, OutgoingRequest> = infos .get_unsent_key_requests()
.iter() .await?
.into_iter()
.filter(|i| !i.sent_out) .filter(|i| !i.sent_out)
.filter_map(|info| { .map(|info| {
Some(( info.to_request(self.user_id(), self.device_id())
info.request_id, .map_err(CryptoStoreError::from)
info.into_request(self.user_id(), self.device_id()).ok()?,
))
}) })
.collect(); .collect()
self.outgoing_to_device_requests = requests.into();
Ok(())
} }
/// Our own user id. /// Our own user id.
@ -250,12 +245,18 @@ impl KeyRequestMachine {
&self.device_id &self.device_id
} }
pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> { pub async fn outgoing_to_device_requests(
#[allow(clippy::map_clone)] &self,
self.outgoing_to_device_requests ) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let mut key_requests = self.load_outgoing_requests().await?;
let key_forwards: Vec<OutgoingRequest> = self
.outgoing_to_device_requests
.iter() .iter()
.map(|r| (*r).clone()) .map(|i| i.value().clone())
.collect() .collect();
key_requests.extend(key_forwards);
Ok(key_requests)
} }
/// Receive a room key request event. /// Receive a room key request event.
@ -584,10 +585,7 @@ impl KeyRequestMachine {
sent_out: false, sent_out: false,
}; };
let request = info.into_request(self.user_id(), self.device_id())?;
self.save_outgoing_key_info(info).await?; self.save_outgoing_key_info(info).await?;
self.outgoing_to_device_requests.insert(id, request);
Ok(()) Ok(())
} }
@ -830,7 +828,11 @@ mod test {
async fn create_machine() { async fn create_machine() {
let machine = get_machine().await; let machine = get_machine().await;
assert!(machine.outgoing_to_device_requests().is_empty()); assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
} }
#[async_test] #[async_test]
@ -843,7 +845,11 @@ mod test {
.await .await
.unwrap(); .unwrap();
assert!(machine.outgoing_to_device_requests().is_empty()); assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
machine machine
.create_outgoing_key_request( .create_outgoing_key_request(
session.room_id(), session.room_id(),
@ -852,8 +858,15 @@ mod test {
) )
.await .await
.unwrap(); .unwrap();
assert!(!machine.outgoing_to_device_requests().is_empty()); assert!(!machine
assert_eq!(machine.outgoing_to_device_requests().len(), 1); .outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert_eq!(
machine.outgoing_to_device_requests().await.unwrap().len(),
1
);
machine machine
.create_outgoing_key_request( .create_outgoing_key_request(
@ -863,15 +876,21 @@ mod test {
) )
.await .await
.unwrap(); .unwrap();
assert_eq!(machine.outgoing_to_device_requests.len(), 1);
let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let requests = machine.outgoing_to_device_requests().await.unwrap();
assert_eq!(requests.len(), 1);
let id = request.request_id; let request = requests.get(0).unwrap();
drop(request);
machine.mark_outgoing_request_as_sent(id).await.unwrap(); machine
assert!(machine.outgoing_to_device_requests.is_empty()); .mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
} }
#[async_test] #[async_test]
@ -892,9 +911,9 @@ mod test {
.await .await
.unwrap(); .unwrap();
let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let requests = machine.outgoing_to_device_requests().await.unwrap();
let request = requests.get(0).unwrap();
let id = request.request_id; let id = request.request_id;
drop(request);
machine.mark_outgoing_request_as_sent(id).await.unwrap(); machine.mark_outgoing_request_as_sent(id).await.unwrap();
@ -949,11 +968,13 @@ mod test {
.await .await
.unwrap(); .unwrap();
let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let requests = machine.outgoing_to_device_requests().await.unwrap();
let id = request.request_id; let request = &requests[0];
drop(request);
machine.mark_outgoing_request_as_sent(id).await.unwrap(); machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
let export = session.export_at_index(15).await; let export = session.export_at_index(15).await;
@ -1160,11 +1181,8 @@ mod test {
.insert(group_session.clone()); .insert(group_session.clone());
// Get the request and convert it into a event. // Get the request and convert it into a event.
let request = alice_machine let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
.outgoing_to_device_requests let request = &requests[0];
.iter()
.next()
.unwrap();
let id = request.request_id; let id = request.request_id;
let content = request let content = request
.request .request
@ -1178,7 +1196,6 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent = let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap(); serde_json::from_str(content.get()).unwrap();
drop(request);
alice_machine alice_machine
.mark_outgoing_request_as_sent(id) .mark_outgoing_request_as_sent(id)
.await .await
@ -1199,11 +1216,8 @@ mod test {
assert!(!bob_machine.outgoing_to_device_requests.is_empty()); assert!(!bob_machine.outgoing_to_device_requests.is_empty());
// Get the request and convert it to a encrypted to-device event. // Get the request and convert it to a encrypted to-device event.
let request = bob_machine let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
.outgoing_to_device_requests let request = &requests[0];
.iter()
.next()
.unwrap();
let id = request.request_id; let id = request.request_id;
let content = request let content = request
@ -1217,7 +1231,6 @@ mod test {
.unwrap(); .unwrap();
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
drop(request);
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent {
@ -1326,11 +1339,8 @@ mod test {
.insert(group_session.clone()); .insert(group_session.clone());
// Get the request and convert it into a event. // Get the request and convert it into a event.
let request = alice_machine let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
.outgoing_to_device_requests let request = &requests[0];
.iter()
.next()
.unwrap();
let id = request.request_id; let id = request.request_id;
let content = request let content = request
.request .request
@ -1344,7 +1354,6 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent = let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap(); serde_json::from_str(content.get()).unwrap();
drop(request);
alice_machine alice_machine
.mark_outgoing_request_as_sent(id) .mark_outgoing_request_as_sent(id)
.await .await
@ -1356,7 +1365,11 @@ mod test {
}; };
// Bob doesn't have any outgoing requests. // Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests.is_empty()); assert!(bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.users_for_key_claim.is_empty());
assert!(bob_machine.wait_queue.is_empty()); assert!(bob_machine.wait_queue.is_empty());
@ -1364,7 +1377,11 @@ mod test {
bob_machine.receive_incoming_key_request(&event); bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap(); bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob doens't have an outgoing requests since we're lacking a session. // Bob doens't have an outgoing requests since we're lacking a session.
assert!(bob_machine.outgoing_to_device_requests.is_empty()); assert!(bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(!bob_machine.users_for_key_claim.is_empty()); assert!(!bob_machine.users_for_key_claim.is_empty());
assert!(!bob_machine.wait_queue.is_empty()); assert!(!bob_machine.wait_queue.is_empty());
@ -1384,15 +1401,17 @@ mod test {
assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.users_for_key_claim.is_empty());
bob_machine.collect_incoming_key_requests().await.unwrap(); bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob now has an outgoing requests. // Bob now has an outgoing requests.
assert!(!bob_machine.outgoing_to_device_requests.is_empty()); assert!(!bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.wait_queue.is_empty()); assert!(bob_machine.wait_queue.is_empty());
// Get the request and convert it to a encrypted to-device event. // Get the request and convert it to a encrypted to-device event.
let request = bob_machine let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
.outgoing_to_device_requests
.iter() let request = &requests[0];
.next()
.unwrap();
let id = request.request_id; let id = request.request_id;
let content = request let content = request
@ -1406,7 +1425,6 @@ mod test {
.unwrap(); .unwrap();
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
drop(request);
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent {

View File

@ -245,10 +245,9 @@ impl OlmMachine {
} }
}; };
let mut machine = OlmMachine::new_helper(&user_id, device_id, store, account, identity); Ok(OlmMachine::new_helper(
machine.key_request_machine.load_outgoing_requests().await?; &user_id, device_id, store, account, identity,
))
Ok(machine)
} }
/// Create a new machine with the default crypto store. /// Create a new machine with the default crypto store.
@ -295,7 +294,7 @@ impl OlmMachine {
/// machine using [`mark_request_as_sent`]. /// machine using [`mark_request_as_sent`].
/// ///
/// [`mark_request_as_sent`]: #method.mark_request_as_sent /// [`mark_request_as_sent`]: #method.mark_request_as_sent
pub async fn outgoing_requests(&self) -> Vec<OutgoingRequest> { pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
let mut requests = Vec::new(); let mut requests = Vec::new();
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest { if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
@ -320,9 +319,14 @@ impl OlmMachine {
requests.append(&mut self.outgoing_to_device_requests()); requests.append(&mut self.outgoing_to_device_requests());
requests.append(&mut self.verification_machine.outgoing_room_message_requests()); requests.append(&mut self.verification_machine.outgoing_room_message_requests());
requests.append(&mut self.key_request_machine.outgoing_to_device_requests()); requests.append(
&mut self
.key_request_machine
.outgoing_to_device_requests()
.await?,
);
requests Ok(requests)
} }
/// Mark the request with the given request id as sent. /// Mark the request with the given request id as sent.

View File

@ -268,10 +268,11 @@ impl CryptoStore for MemoryStore {
.and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone())))
} }
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> { async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
Ok(self Ok(self
.outgoing_key_requests .outgoing_key_requests
.iter() .iter()
.filter(|i| !i.value().sent_out)
.map(|i| i.value().clone()) .map(|i| i.value().clone())
.collect()) .collect())
} }

View File

@ -450,7 +450,7 @@ pub trait CryptoStore: AsyncTraitDeps {
) -> Result<Option<OutgoingKeyRequest>>; ) -> Result<Option<OutgoingKeyRequest>>;
/// Get all outgoing key requests that we have in the store. /// Get all outgoing key requests that we have in the store.
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>; async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
/// Delete an outoing key request that we created that matches the given /// Delete an outoing key request that we created that matches the given
/// request id. /// request id.

View File

@ -149,6 +149,7 @@ pub struct SledStore {
outbound_group_sessions: Tree, outbound_group_sessions: Tree,
outgoing_key_requests: Tree, outgoing_key_requests: Tree,
unsent_key_requests: Tree,
key_requests_by_info: Tree, key_requests_by_info: Tree,
devices: Tree, devices: Tree,
@ -215,6 +216,7 @@ impl SledStore {
let identities = db.open_tree("identities")?; let identities = db.open_tree("identities")?;
let outgoing_key_requests = db.open_tree("outgoing_key_requests")?; let outgoing_key_requests = db.open_tree("outgoing_key_requests")?;
let unsent_key_requests = db.open_tree("unsent_key_requests")?;
let key_requests_by_info = db.open_tree("key_requests_by_info")?; let key_requests_by_info = db.open_tree("key_requests_by_info")?;
let session_cache = SessionStore::new(); let session_cache = SessionStore::new();
@ -240,6 +242,7 @@ impl SledStore {
inbound_group_sessions, inbound_group_sessions,
outbound_group_sessions, outbound_group_sessions,
outgoing_key_requests, outgoing_key_requests,
unsent_key_requests,
key_requests_by_info, key_requests_by_info,
devices, devices,
tracked_users, tracked_users,
@ -376,6 +379,7 @@ impl SledStore {
&self.outbound_group_sessions, &self.outbound_group_sessions,
&self.olm_hashes, &self.olm_hashes,
&self.outgoing_key_requests, &self.outgoing_key_requests,
&self.unsent_key_requests,
&self.key_requests_by_info, &self.key_requests_by_info,
) )
.transaction( .transaction(
@ -389,6 +393,7 @@ impl SledStore {
outbound_sessions, outbound_sessions,
hashes, hashes,
outgoing_key_requests, outgoing_key_requests,
unsent_key_requests,
key_requests_by_info, key_requests_by_info,
)| { )| {
if let Some(a) = &account_pickle { if let Some(a) = &account_pickle {
@ -463,11 +468,23 @@ impl SledStore {
key_request.request_id.encode(), key_request.request_id.encode(),
)?; )?;
outgoing_key_requests.insert( let key_request_id = key_request.request_id.encode();
key_request.request_id.encode(),
serde_json::to_vec(&key_request) if key_request.sent_out {
.map_err(ConflictableTransactionError::Abort)?, unsent_key_requests.remove(key_request_id.clone())?;
)?; outgoing_key_requests.insert(
key_request_id,
serde_json::to_vec(&key_request)
.map_err(ConflictableTransactionError::Abort)?,
)?;
} else {
outgoing_key_requests.remove(key_request_id.clone())?;
unsent_key_requests.insert(
key_request_id,
serde_json::to_vec(&key_request)
.map_err(ConflictableTransactionError::Abort)?,
)?;
}
} }
Ok(()) Ok(())
@ -479,6 +496,28 @@ impl SledStore {
Ok(()) Ok(())
} }
async fn get_outgoing_key_request_helper(
&self,
id: &[u8],
) -> Result<Option<OutgoingKeyRequest>> {
let request = self
.outgoing_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?;
let request = if request.is_none() {
self.unsent_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?
} else {
request
};
Ok(request)
}
} }
#[async_trait] #[async_trait]
@ -685,11 +724,9 @@ impl CryptoStore for SledStore {
&self, &self,
request_id: Uuid, request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
Ok(self let request_id = request_id.encode();
.outgoing_key_requests
.get(request_id.encode())? self.get_outgoing_key_request_helper(&request_id).await
.map(|r| serde_json::from_slice(&r))
.transpose()?)
} }
async fn get_key_request_by_info( async fn get_key_request_by_info(
@ -699,19 +736,15 @@ impl CryptoStore for SledStore {
let id = self.key_requests_by_info.get(key_info.encode())?; let id = self.key_requests_by_info.get(key_info.encode())?;
if let Some(id) = id { if let Some(id) = id {
Ok(self self.get_outgoing_key_request_helper(&id).await
.outgoing_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?)
} else { } else {
Ok(None) Ok(None)
} }
} }
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> { async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
let requests: Result<Vec<OutgoingKeyRequest>> = self let requests: Result<Vec<OutgoingKeyRequest>> = self
.outgoing_key_requests .unsent_key_requests
.iter() .iter()
.map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from)) .map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from))
.collect(); .collect();
@ -720,16 +753,30 @@ impl CryptoStore for SledStore {
} }
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
let ret: Result<(), TransactionError<serde_json::Error>> = let ret: Result<(), TransactionError<serde_json::Error>> = (
(&self.outgoing_key_requests, &self.key_requests_by_info).transaction( &self.outgoing_key_requests,
|(outgoing_key_requests, key_requests_by_info)| { &self.unsent_key_requests,
let request: Option<OutgoingKeyRequest> = outgoing_key_requests &self.key_requests_by_info,
)
.transaction(
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
.remove(request_id.encode())? .remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r)) .map(|r| serde_json::from_slice(&r))
.transpose() .transpose()
.map_err(ConflictableTransactionError::Abort)?; .map_err(ConflictableTransactionError::Abort)?;
if let Some(request) = request { let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
if let Some(request) = sent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
if let Some(request) = unsent_request {
key_requests_by_info.remove((&request.info).encode())?; key_requests_by_info.remove((&request.info).encode())?;
} }
@ -1328,7 +1375,21 @@ mod test {
let stored_request = store.get_key_request_by_info(&info).await.unwrap(); let stored_request = store.get_key_request_by_info(&info).await.unwrap();
assert_eq!(request, stored_request); assert_eq!(request, stored_request);
assert!(!store.get_outgoing_key_requests().await.unwrap().is_empty()); assert!(!store.get_unsent_key_requests().await.unwrap().is_empty());
let request = OutgoingKeyRequest {
request_id: id,
info: info.clone(),
sent_out: true,
};
let mut changes = Changes::default();
changes.key_requests.push(request.clone());
store.save_changes(changes).await.unwrap();
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
assert_eq!(Some(request), stored_request);
store.delete_outgoing_key_request(id).await.unwrap(); store.delete_outgoing_key_request(id).await.unwrap();
@ -1337,6 +1398,6 @@ mod test {
let stored_request = store.get_key_request_by_info(&info).await.unwrap(); let stored_request = store.get_key_request_by_info(&info).await.unwrap();
assert_eq!(None, stored_request); assert_eq!(None, stored_request);
assert!(store.get_outgoing_key_requests().await.unwrap().is_empty()); assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
} }
} }