crypto: Store our own device we receive from the server

This commit is contained in:
Damir Jelić 2021-08-04 16:38:43 +02:00
parent 7d851a10b5
commit 0598bdebc7
4 changed files with 82 additions and 20 deletions

View file

@ -938,11 +938,16 @@ mod test {
let user_id: Arc<UserId> = alice_id().into();
let account = ReadOnlyAccount::new(&user_id, &alice_device_id());
let device = ReadOnlyDevice::from_account(&account).await;
let another_device =
ReadOnlyDevice::from_account(&ReadOnlyAccount::new(&user_id, &alice2_device_id()))
.await;
let store: Arc<dyn CryptoStore> = Arc::new(MemoryStore::new());
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
let store = Store::new(user_id.clone(), identity, store, verification);
store.save_devices(&[device]).await.unwrap();
store.save_devices(&[device, another_device]).await.unwrap();
let session_cache = GroupSessionCache::new(store.clone());
GossipMachine::new(
@ -1136,7 +1141,7 @@ mod test {
let account = account();
let own_device =
machine.store.get_device(&alice_id(), &alice_device_id()).await.unwrap().unwrap();
machine.store.get_device(&alice_id(), &alice2_device_id()).await.unwrap().unwrap();
let (outbound, inbound) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();

View file

@ -137,18 +137,20 @@ impl IdentityManager {
user_id: UserId,
device_map: BTreeMap<DeviceIdBox, DeviceKeys>,
) -> StoreResult<DeviceChanges> {
let own_device_id = (&*own_device_id).to_owned();
let mut changes = DeviceChanges::default();
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
let tasks = device_map.into_iter().filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store.
if user_id == *own_user_id && *device_id == *own_device_id {
None
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
if user_id != device_keys.user_id || device_id != device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
user_id = user_id.as_str(),
device_id = device_id.as_str(),
device_key_user = device_keys.user_id.as_str(),
device_key_device_id = device_keys.device_id.as_str(),
"Mismatch in the device keys payload",
);
None
} else {
@ -169,13 +171,14 @@ impl IdentityManager {
}
let current_devices: HashSet<&DeviceIdBox> = current_devices.iter().collect();
let stored_devices = store.get_readonly_devices(&user_id).await?;
let stored_devices = store.get_readonly_devices_unfiltered(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices_set {
if let Some(device) = stored_devices.get(*device_id) {
if user_id == *own_user_id && *device_id == &own_device_id {
warn!("Our own device has been deleted");
} else if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
changes.deleted.push(device.clone());
}

View file

@ -287,6 +287,11 @@ impl OlmMachine {
self.account.identity_keys()
}
/// Get the display name of our own device
pub async fn dislpay_name(&self) -> StoreResult<Option<String>> {
self.store.device_display_name().await
}
/// Get the outgoing requests that need to be sent out.
///
/// This returns a list of `OutGoingRequest`, those requests need to be sent

View file

@ -162,19 +162,19 @@ impl Store {
Self { user_id, identity, inner: store, verification_machine }
}
pub fn user_id(&self) -> &UserId {
&self.user_id
}
pub fn device_id(&self) -> &DeviceId {
self.verification_machine.own_device_id()
}
#[cfg(test)]
pub async fn reset_cross_signing_identity(&self) {
self.identity.lock().await.reset().await;
}
pub async fn get_readonly_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Option<ReadOnlyDevice>> {
self.inner.get_device(user_id, device_id).await
}
pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
@ -201,13 +201,58 @@ impl Store {
self.save_changes(changes).await
}
/// Get the display name of our own device.
pub async fn device_display_name(&self) -> Result<Option<String>, CryptoStoreError> {
Ok(self
.inner
.get_device(self.user_id(), self.device_id())
.await?
.and_then(|d| d.display_name().to_owned()))
}
/// Get the read-only version of all the devices that the given user has.
///
/// *Note*: This doesn't return our own device.
pub async fn get_readonly_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Option<ReadOnlyDevice>> {
if user_id == self.user_id() && device_id == self.device_id() {
Ok(None)
} else {
self.inner.get_device(user_id, device_id).await
}
}
/// Get the read-only version of all the devices that the given user has.
///
/// *Note*: This doesn't return our own device.
pub async fn get_readonly_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
self.inner.get_user_devices(user_id).await.map(|mut d| {
if user_id == self.user_id() {
d.remove(self.device_id());
}
d
})
}
/// Get the read-only version of all the devices that the given user has.
///
/// *Note*: This does also return our own device.
pub async fn get_readonly_devices_unfiltered(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
self.inner.get_user_devices(user_id).await
}
/// Get a device for the given user with the given curve25519 key.
///
/// *Note*: This doesn't return our own device.
pub async fn get_device_from_curve_key(
&self,
user_id: &UserId,
@ -221,7 +266,11 @@ impl Store {
}
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
let devices = self.inner.get_user_devices(user_id).await?;
let mut devices = self.inner.get_user_devices(user_id).await?;
if user_id == self.user_id() {
devices.remove(self.device_id());
}
let own_identity =
self.inner.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();