crypto: Only load the outgoing key requests when we want to send them out
parent
f9d290746c
commit
8c007510cd
|
@ -1839,7 +1839,16 @@ impl Client {
|
||||||
warn!("Error while claiming one-time keys {:?}", e);
|
warn!("Error while claiming one-time keys {:?}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
for r in self.base_client.outgoing_requests().await {
|
// TODO we should probably abort if we get an cryptostore error here
|
||||||
|
let outgoing_requests = match self.base_client.outgoing_requests().await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Could not fetch the outgoing requests {:?}", e);
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for r in outgoing_requests {
|
||||||
match r.request() {
|
match r.request() {
|
||||||
OutgoingRequests::KeysQuery(request) => {
|
OutgoingRequests::KeysQuery(request) => {
|
||||||
if let Err(e) = self
|
if let Err(e) = self
|
||||||
|
|
|
@ -1069,12 +1069,12 @@ impl BaseClient {
|
||||||
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||||
pub async fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
|
pub async fn outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||||
let olm = self.olm.lock().await;
|
let olm = self.olm.lock().await;
|
||||||
|
|
||||||
match &*olm {
|
match &*olm {
|
||||||
Some(o) => o.outgoing_requests().await,
|
Some(o) => o.outgoing_requests().await,
|
||||||
None => vec![],
|
None => Ok(vec![]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -150,7 +150,7 @@ pub struct OutgoingKeyRequest {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OutgoingKeyRequest {
|
impl OutgoingKeyRequest {
|
||||||
fn into_request(
|
fn to_request(
|
||||||
&self,
|
&self,
|
||||||
recipient: &UserId,
|
recipient: &UserId,
|
||||||
own_device_id: &DeviceId,
|
own_device_id: &DeviceId,
|
||||||
|
@ -214,30 +214,25 @@ impl KeyRequestMachine {
|
||||||
device_id,
|
device_id,
|
||||||
store,
|
store,
|
||||||
outbound_group_sessions,
|
outbound_group_sessions,
|
||||||
outgoing_to_device_requests: Arc::new(DashMap::new()),
|
outgoing_to_device_requests: DashMap::new().into(),
|
||||||
incoming_key_requests: Arc::new(DashMap::new()),
|
incoming_key_requests: DashMap::new().into(),
|
||||||
wait_queue: WaitQueue::new(),
|
wait_queue: WaitQueue::new(),
|
||||||
users_for_key_claim,
|
users_for_key_claim,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load stored non-sent out outgoing requests
|
/// Load stored outgoing requests that were not yet sent out.
|
||||||
pub async fn load_outgoing_requests(&mut self) -> Result<(), CryptoStoreError> {
|
async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||||
let infos: Vec<OutgoingKeyRequest> = vec![];
|
self.store
|
||||||
let requests: DashMap<Uuid, OutgoingRequest> = infos
|
.get_unsent_key_requests()
|
||||||
.iter()
|
.await?
|
||||||
|
.into_iter()
|
||||||
.filter(|i| !i.sent_out)
|
.filter(|i| !i.sent_out)
|
||||||
.filter_map(|info| {
|
.map(|info| {
|
||||||
Some((
|
info.to_request(self.user_id(), self.device_id())
|
||||||
info.request_id,
|
.map_err(CryptoStoreError::from)
|
||||||
info.into_request(self.user_id(), self.device_id()).ok()?,
|
|
||||||
))
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect()
|
||||||
|
|
||||||
self.outgoing_to_device_requests = requests.into();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Our own user id.
|
/// Our own user id.
|
||||||
|
@ -250,12 +245,18 @@ impl KeyRequestMachine {
|
||||||
&self.device_id
|
&self.device_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> {
|
pub async fn outgoing_to_device_requests(
|
||||||
#[allow(clippy::map_clone)]
|
&self,
|
||||||
self.outgoing_to_device_requests
|
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||||
|
let mut key_requests = self.load_outgoing_requests().await?;
|
||||||
|
let key_forwards: Vec<OutgoingRequest> = self
|
||||||
|
.outgoing_to_device_requests
|
||||||
.iter()
|
.iter()
|
||||||
.map(|r| (*r).clone())
|
.map(|i| i.value().clone())
|
||||||
.collect()
|
.collect();
|
||||||
|
key_requests.extend(key_forwards);
|
||||||
|
|
||||||
|
Ok(key_requests)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Receive a room key request event.
|
/// Receive a room key request event.
|
||||||
|
@ -584,10 +585,7 @@ impl KeyRequestMachine {
|
||||||
sent_out: false,
|
sent_out: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let request = info.into_request(self.user_id(), self.device_id())?;
|
|
||||||
|
|
||||||
self.save_outgoing_key_info(info).await?;
|
self.save_outgoing_key_info(info).await?;
|
||||||
self.outgoing_to_device_requests.insert(id, request);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -830,7 +828,11 @@ mod test {
|
||||||
async fn create_machine() {
|
async fn create_machine() {
|
||||||
let machine = get_machine().await;
|
let machine = get_machine().await;
|
||||||
|
|
||||||
assert!(machine.outgoing_to_device_requests().is_empty());
|
assert!(machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_test]
|
#[async_test]
|
||||||
|
@ -843,7 +845,11 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(machine.outgoing_to_device_requests().is_empty());
|
assert!(machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
machine
|
machine
|
||||||
.create_outgoing_key_request(
|
.create_outgoing_key_request(
|
||||||
session.room_id(),
|
session.room_id(),
|
||||||
|
@ -852,8 +858,15 @@ mod test {
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(!machine.outgoing_to_device_requests().is_empty());
|
assert!(!machine
|
||||||
assert_eq!(machine.outgoing_to_device_requests().len(), 1);
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
|
assert_eq!(
|
||||||
|
machine.outgoing_to_device_requests().await.unwrap().len(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
|
||||||
machine
|
machine
|
||||||
.create_outgoing_key_request(
|
.create_outgoing_key_request(
|
||||||
|
@ -863,15 +876,21 @@ mod test {
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(machine.outgoing_to_device_requests.len(), 1);
|
|
||||||
|
|
||||||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
let requests = machine.outgoing_to_device_requests().await.unwrap();
|
||||||
|
assert_eq!(requests.len(), 1);
|
||||||
|
|
||||||
let id = request.request_id;
|
let request = requests.get(0).unwrap();
|
||||||
drop(request);
|
|
||||||
|
|
||||||
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
machine
|
||||||
assert!(machine.outgoing_to_device_requests.is_empty());
|
.mark_outgoing_request_as_sent(request.request_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_test]
|
#[async_test]
|
||||||
|
@ -892,9 +911,9 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
let requests = machine.outgoing_to_device_requests().await.unwrap();
|
||||||
|
let request = requests.get(0).unwrap();
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
drop(request);
|
|
||||||
|
|
||||||
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||||
|
|
||||||
|
@ -949,11 +968,13 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
let requests = machine.outgoing_to_device_requests().await.unwrap();
|
||||||
let id = request.request_id;
|
let request = &requests[0];
|
||||||
drop(request);
|
|
||||||
|
|
||||||
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
machine
|
||||||
|
.mark_outgoing_request_as_sent(request.request_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let export = session.export_at_index(15).await;
|
let export = session.export_at_index(15).await;
|
||||||
|
|
||||||
|
@ -1160,11 +1181,8 @@ mod test {
|
||||||
.insert(group_session.clone());
|
.insert(group_session.clone());
|
||||||
|
|
||||||
// Get the request and convert it into a event.
|
// Get the request and convert it into a event.
|
||||||
let request = alice_machine
|
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
|
||||||
.outgoing_to_device_requests
|
let request = &requests[0];
|
||||||
.iter()
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
let content = request
|
let content = request
|
||||||
.request
|
.request
|
||||||
|
@ -1178,7 +1196,6 @@ mod test {
|
||||||
let content: RoomKeyRequestToDeviceEventContent =
|
let content: RoomKeyRequestToDeviceEventContent =
|
||||||
serde_json::from_str(content.get()).unwrap();
|
serde_json::from_str(content.get()).unwrap();
|
||||||
|
|
||||||
drop(request);
|
|
||||||
alice_machine
|
alice_machine
|
||||||
.mark_outgoing_request_as_sent(id)
|
.mark_outgoing_request_as_sent(id)
|
||||||
.await
|
.await
|
||||||
|
@ -1199,11 +1216,8 @@ mod test {
|
||||||
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
|
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
|
||||||
|
|
||||||
// Get the request and convert it to a encrypted to-device event.
|
// Get the request and convert it to a encrypted to-device event.
|
||||||
let request = bob_machine
|
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
|
||||||
.outgoing_to_device_requests
|
let request = &requests[0];
|
||||||
.iter()
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
let content = request
|
let content = request
|
||||||
|
@ -1217,7 +1231,6 @@ mod test {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
||||||
|
|
||||||
drop(request);
|
|
||||||
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||||
|
|
||||||
let event = ToDeviceEvent {
|
let event = ToDeviceEvent {
|
||||||
|
@ -1326,11 +1339,8 @@ mod test {
|
||||||
.insert(group_session.clone());
|
.insert(group_session.clone());
|
||||||
|
|
||||||
// Get the request and convert it into a event.
|
// Get the request and convert it into a event.
|
||||||
let request = alice_machine
|
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
|
||||||
.outgoing_to_device_requests
|
let request = &requests[0];
|
||||||
.iter()
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
let content = request
|
let content = request
|
||||||
.request
|
.request
|
||||||
|
@ -1344,7 +1354,6 @@ mod test {
|
||||||
let content: RoomKeyRequestToDeviceEventContent =
|
let content: RoomKeyRequestToDeviceEventContent =
|
||||||
serde_json::from_str(content.get()).unwrap();
|
serde_json::from_str(content.get()).unwrap();
|
||||||
|
|
||||||
drop(request);
|
|
||||||
alice_machine
|
alice_machine
|
||||||
.mark_outgoing_request_as_sent(id)
|
.mark_outgoing_request_as_sent(id)
|
||||||
.await
|
.await
|
||||||
|
@ -1356,7 +1365,11 @@ mod test {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Bob doesn't have any outgoing requests.
|
// Bob doesn't have any outgoing requests.
|
||||||
assert!(bob_machine.outgoing_to_device_requests.is_empty());
|
assert!(bob_machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
assert!(bob_machine.users_for_key_claim.is_empty());
|
assert!(bob_machine.users_for_key_claim.is_empty());
|
||||||
assert!(bob_machine.wait_queue.is_empty());
|
assert!(bob_machine.wait_queue.is_empty());
|
||||||
|
|
||||||
|
@ -1364,7 +1377,11 @@ mod test {
|
||||||
bob_machine.receive_incoming_key_request(&event);
|
bob_machine.receive_incoming_key_request(&event);
|
||||||
bob_machine.collect_incoming_key_requests().await.unwrap();
|
bob_machine.collect_incoming_key_requests().await.unwrap();
|
||||||
// Bob doens't have an outgoing requests since we're lacking a session.
|
// Bob doens't have an outgoing requests since we're lacking a session.
|
||||||
assert!(bob_machine.outgoing_to_device_requests.is_empty());
|
assert!(bob_machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
assert!(!bob_machine.users_for_key_claim.is_empty());
|
assert!(!bob_machine.users_for_key_claim.is_empty());
|
||||||
assert!(!bob_machine.wait_queue.is_empty());
|
assert!(!bob_machine.wait_queue.is_empty());
|
||||||
|
|
||||||
|
@ -1384,15 +1401,17 @@ mod test {
|
||||||
assert!(bob_machine.users_for_key_claim.is_empty());
|
assert!(bob_machine.users_for_key_claim.is_empty());
|
||||||
bob_machine.collect_incoming_key_requests().await.unwrap();
|
bob_machine.collect_incoming_key_requests().await.unwrap();
|
||||||
// Bob now has an outgoing requests.
|
// Bob now has an outgoing requests.
|
||||||
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
|
assert!(!bob_machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
assert!(bob_machine.wait_queue.is_empty());
|
assert!(bob_machine.wait_queue.is_empty());
|
||||||
|
|
||||||
// Get the request and convert it to a encrypted to-device event.
|
// Get the request and convert it to a encrypted to-device event.
|
||||||
let request = bob_machine
|
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
|
||||||
.outgoing_to_device_requests
|
|
||||||
.iter()
|
let request = &requests[0];
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
let content = request
|
let content = request
|
||||||
|
@ -1406,7 +1425,6 @@ mod test {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
||||||
|
|
||||||
drop(request);
|
|
||||||
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||||
|
|
||||||
let event = ToDeviceEvent {
|
let event = ToDeviceEvent {
|
||||||
|
|
|
@ -245,10 +245,9 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut machine = OlmMachine::new_helper(&user_id, device_id, store, account, identity);
|
Ok(OlmMachine::new_helper(
|
||||||
machine.key_request_machine.load_outgoing_requests().await?;
|
&user_id, device_id, store, account, identity,
|
||||||
|
))
|
||||||
Ok(machine)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new machine with the default crypto store.
|
/// Create a new machine with the default crypto store.
|
||||||
|
@ -295,7 +294,7 @@ impl OlmMachine {
|
||||||
/// machine using [`mark_request_as_sent`].
|
/// machine using [`mark_request_as_sent`].
|
||||||
///
|
///
|
||||||
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
||||||
pub async fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
|
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
|
|
||||||
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
|
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
|
||||||
|
@ -320,9 +319,14 @@ impl OlmMachine {
|
||||||
|
|
||||||
requests.append(&mut self.outgoing_to_device_requests());
|
requests.append(&mut self.outgoing_to_device_requests());
|
||||||
requests.append(&mut self.verification_machine.outgoing_room_message_requests());
|
requests.append(&mut self.verification_machine.outgoing_room_message_requests());
|
||||||
requests.append(&mut self.key_request_machine.outgoing_to_device_requests());
|
requests.append(
|
||||||
|
&mut self
|
||||||
|
.key_request_machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await?,
|
||||||
|
);
|
||||||
|
|
||||||
requests
|
Ok(requests)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark the request with the given request id as sent.
|
/// Mark the request with the given request id as sent.
|
||||||
|
|
|
@ -268,10 +268,11 @@ impl CryptoStore for MemoryStore {
|
||||||
.and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone())))
|
.and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone())))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
||||||
Ok(self
|
Ok(self
|
||||||
.outgoing_key_requests
|
.outgoing_key_requests
|
||||||
.iter()
|
.iter()
|
||||||
|
.filter(|i| !i.value().sent_out)
|
||||||
.map(|i| i.value().clone())
|
.map(|i| i.value().clone())
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
|
@ -450,7 +450,7 @@ pub trait CryptoStore: AsyncTraitDeps {
|
||||||
) -> Result<Option<OutgoingKeyRequest>>;
|
) -> Result<Option<OutgoingKeyRequest>>;
|
||||||
|
|
||||||
/// Get all outgoing key requests that we have in the store.
|
/// Get all outgoing key requests that we have in the store.
|
||||||
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
|
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
|
||||||
|
|
||||||
/// Delete an outoing key request that we created that matches the given
|
/// Delete an outoing key request that we created that matches the given
|
||||||
/// request id.
|
/// request id.
|
||||||
|
|
|
@ -149,6 +149,7 @@ pub struct SledStore {
|
||||||
outbound_group_sessions: Tree,
|
outbound_group_sessions: Tree,
|
||||||
|
|
||||||
outgoing_key_requests: Tree,
|
outgoing_key_requests: Tree,
|
||||||
|
unsent_key_requests: Tree,
|
||||||
key_requests_by_info: Tree,
|
key_requests_by_info: Tree,
|
||||||
|
|
||||||
devices: Tree,
|
devices: Tree,
|
||||||
|
@ -215,6 +216,7 @@ impl SledStore {
|
||||||
let identities = db.open_tree("identities")?;
|
let identities = db.open_tree("identities")?;
|
||||||
|
|
||||||
let outgoing_key_requests = db.open_tree("outgoing_key_requests")?;
|
let outgoing_key_requests = db.open_tree("outgoing_key_requests")?;
|
||||||
|
let unsent_key_requests = db.open_tree("unsent_key_requests")?;
|
||||||
let key_requests_by_info = db.open_tree("key_requests_by_info")?;
|
let key_requests_by_info = db.open_tree("key_requests_by_info")?;
|
||||||
|
|
||||||
let session_cache = SessionStore::new();
|
let session_cache = SessionStore::new();
|
||||||
|
@ -240,6 +242,7 @@ impl SledStore {
|
||||||
inbound_group_sessions,
|
inbound_group_sessions,
|
||||||
outbound_group_sessions,
|
outbound_group_sessions,
|
||||||
outgoing_key_requests,
|
outgoing_key_requests,
|
||||||
|
unsent_key_requests,
|
||||||
key_requests_by_info,
|
key_requests_by_info,
|
||||||
devices,
|
devices,
|
||||||
tracked_users,
|
tracked_users,
|
||||||
|
@ -376,6 +379,7 @@ impl SledStore {
|
||||||
&self.outbound_group_sessions,
|
&self.outbound_group_sessions,
|
||||||
&self.olm_hashes,
|
&self.olm_hashes,
|
||||||
&self.outgoing_key_requests,
|
&self.outgoing_key_requests,
|
||||||
|
&self.unsent_key_requests,
|
||||||
&self.key_requests_by_info,
|
&self.key_requests_by_info,
|
||||||
)
|
)
|
||||||
.transaction(
|
.transaction(
|
||||||
|
@ -389,6 +393,7 @@ impl SledStore {
|
||||||
outbound_sessions,
|
outbound_sessions,
|
||||||
hashes,
|
hashes,
|
||||||
outgoing_key_requests,
|
outgoing_key_requests,
|
||||||
|
unsent_key_requests,
|
||||||
key_requests_by_info,
|
key_requests_by_info,
|
||||||
)| {
|
)| {
|
||||||
if let Some(a) = &account_pickle {
|
if let Some(a) = &account_pickle {
|
||||||
|
@ -463,11 +468,23 @@ impl SledStore {
|
||||||
key_request.request_id.encode(),
|
key_request.request_id.encode(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
let key_request_id = key_request.request_id.encode();
|
||||||
|
|
||||||
|
if key_request.sent_out {
|
||||||
|
unsent_key_requests.remove(key_request_id.clone())?;
|
||||||
outgoing_key_requests.insert(
|
outgoing_key_requests.insert(
|
||||||
key_request.request_id.encode(),
|
key_request_id,
|
||||||
serde_json::to_vec(&key_request)
|
serde_json::to_vec(&key_request)
|
||||||
.map_err(ConflictableTransactionError::Abort)?,
|
.map_err(ConflictableTransactionError::Abort)?,
|
||||||
)?;
|
)?;
|
||||||
|
} else {
|
||||||
|
outgoing_key_requests.remove(key_request_id.clone())?;
|
||||||
|
unsent_key_requests.insert(
|
||||||
|
key_request_id,
|
||||||
|
serde_json::to_vec(&key_request)
|
||||||
|
.map_err(ConflictableTransactionError::Abort)?,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -479,6 +496,28 @@ impl SledStore {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_outgoing_key_request_helper(
|
||||||
|
&self,
|
||||||
|
id: &[u8],
|
||||||
|
) -> Result<Option<OutgoingKeyRequest>> {
|
||||||
|
let request = self
|
||||||
|
.outgoing_key_requests
|
||||||
|
.get(id)?
|
||||||
|
.map(|r| serde_json::from_slice(&r))
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let request = if request.is_none() {
|
||||||
|
self.unsent_key_requests
|
||||||
|
.get(id)?
|
||||||
|
.map(|r| serde_json::from_slice(&r))
|
||||||
|
.transpose()?
|
||||||
|
} else {
|
||||||
|
request
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(request)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
@ -685,11 +724,9 @@ impl CryptoStore for SledStore {
|
||||||
&self,
|
&self,
|
||||||
request_id: Uuid,
|
request_id: Uuid,
|
||||||
) -> Result<Option<OutgoingKeyRequest>> {
|
) -> Result<Option<OutgoingKeyRequest>> {
|
||||||
Ok(self
|
let request_id = request_id.encode();
|
||||||
.outgoing_key_requests
|
|
||||||
.get(request_id.encode())?
|
self.get_outgoing_key_request_helper(&request_id).await
|
||||||
.map(|r| serde_json::from_slice(&r))
|
|
||||||
.transpose()?)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_key_request_by_info(
|
async fn get_key_request_by_info(
|
||||||
|
@ -699,19 +736,15 @@ impl CryptoStore for SledStore {
|
||||||
let id = self.key_requests_by_info.get(key_info.encode())?;
|
let id = self.key_requests_by_info.get(key_info.encode())?;
|
||||||
|
|
||||||
if let Some(id) = id {
|
if let Some(id) = id {
|
||||||
Ok(self
|
self.get_outgoing_key_request_helper(&id).await
|
||||||
.outgoing_key_requests
|
|
||||||
.get(id)?
|
|
||||||
.map(|r| serde_json::from_slice(&r))
|
|
||||||
.transpose()?)
|
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_outgoing_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
||||||
let requests: Result<Vec<OutgoingKeyRequest>> = self
|
let requests: Result<Vec<OutgoingKeyRequest>> = self
|
||||||
.outgoing_key_requests
|
.unsent_key_requests
|
||||||
.iter()
|
.iter()
|
||||||
.map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from))
|
.map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from))
|
||||||
.collect();
|
.collect();
|
||||||
|
@ -720,16 +753,30 @@ impl CryptoStore for SledStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||||
let ret: Result<(), TransactionError<serde_json::Error>> =
|
let ret: Result<(), TransactionError<serde_json::Error>> = (
|
||||||
(&self.outgoing_key_requests, &self.key_requests_by_info).transaction(
|
&self.outgoing_key_requests,
|
||||||
|(outgoing_key_requests, key_requests_by_info)| {
|
&self.unsent_key_requests,
|
||||||
let request: Option<OutgoingKeyRequest> = outgoing_key_requests
|
&self.key_requests_by_info,
|
||||||
|
)
|
||||||
|
.transaction(
|
||||||
|
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
|
||||||
|
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
|
||||||
.remove(request_id.encode())?
|
.remove(request_id.encode())?
|
||||||
.map(|r| serde_json::from_slice(&r))
|
.map(|r| serde_json::from_slice(&r))
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(ConflictableTransactionError::Abort)?;
|
.map_err(ConflictableTransactionError::Abort)?;
|
||||||
|
|
||||||
if let Some(request) = request {
|
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
|
||||||
|
.remove(request_id.encode())?
|
||||||
|
.map(|r| serde_json::from_slice(&r))
|
||||||
|
.transpose()
|
||||||
|
.map_err(ConflictableTransactionError::Abort)?;
|
||||||
|
|
||||||
|
if let Some(request) = sent_request {
|
||||||
|
key_requests_by_info.remove((&request.info).encode())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(request) = unsent_request {
|
||||||
key_requests_by_info.remove((&request.info).encode())?;
|
key_requests_by_info.remove((&request.info).encode())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1328,7 +1375,21 @@ mod test {
|
||||||
|
|
||||||
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
||||||
assert_eq!(request, stored_request);
|
assert_eq!(request, stored_request);
|
||||||
assert!(!store.get_outgoing_key_requests().await.unwrap().is_empty());
|
assert!(!store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||||
|
|
||||||
|
let request = OutgoingKeyRequest {
|
||||||
|
request_id: id,
|
||||||
|
info: info.clone(),
|
||||||
|
sent_out: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut changes = Changes::default();
|
||||||
|
changes.key_requests.push(request.clone());
|
||||||
|
store.save_changes(changes).await.unwrap();
|
||||||
|
|
||||||
|
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||||
|
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||||
|
assert_eq!(Some(request), stored_request);
|
||||||
|
|
||||||
store.delete_outgoing_key_request(id).await.unwrap();
|
store.delete_outgoing_key_request(id).await.unwrap();
|
||||||
|
|
||||||
|
@ -1337,6 +1398,6 @@ mod test {
|
||||||
|
|
||||||
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
||||||
assert_eq!(None, stored_request);
|
assert_eq!(None, stored_request);
|
||||||
assert!(store.get_outgoing_key_requests().await.unwrap().is_empty());
|
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue