diff --git a/keyserver/api/api.go b/keyserver/api/api.go index c3481a38..442af871 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -110,6 +110,11 @@ type OneTimeKeysCount struct { type PerformUploadKeysRequest struct { DeviceKeys []DeviceKeys OneTimeKeys []OneTimeKeys + // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update + // the display name for their respective device, and NOT to modify the keys. The key + // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. + // Without this flag, requests to modify device display names would delete device keys. + OnlyDisplayNameUpdates bool } // PerformUploadKeysResponse is the response to PerformUploadKeys diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 1c4f0b97..573285e8 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -85,8 +85,9 @@ type DeviceListUpdaterDatabase interface { MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error // PrevIDsExists returns true if all prev IDs exist for this user. PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) @@ -144,6 +145,20 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { return u.userIDToMutex[userID] } +// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it. +// Blocks until the device list is synced or the timeout is reached. +func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error { + mu := u.mutex(userID) + mu.Lock() + err := u.db.MarkDeviceListStale(ctx, userID, true) + mu.Unlock() + if err != nil { + return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err) + } + u.notifyWorkers(userID) + return nil +} + // Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest, // which assumes when /send 200 OKs that the device lists have been updated. func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error { @@ -178,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. "stream_id": event.StreamID, "prev_ids": event.PrevID, "display_name": event.DeviceDisplayName, + "deleted": event.Deleted, }).Info("DeviceListUpdater.Update") // if we haven't missed anything update the database and notify users if exists { + k := event.Keys + if event.Deleted { + k = nil + } keys := []api.DeviceMessage{ { DeviceKeys: api.DeviceKeys{ DeviceID: event.DeviceID, DisplayName: event.DeviceDisplayName, - KeyJSON: event.Keys, + KeyJSON: k, UserID: event.UserID, }, StreamID: event.StreamID, }, } - err = u.db.StoreRemoteDeviceKeys(ctx, keys) + err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil) if err != nil { return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) } @@ -348,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi }, } } - err := u.db.StoreRemoteDeviceKeys(ctx, keys) + err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID}) if err != nil { return fmt.Errorf("failed to store remote device keys: %w", err) } diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index dcb981c4..c42a7cdf 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). Does not modify the stream ID for keys. -func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { +func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { d.storedKeys = append(d.storedKeys, keys...) return nil } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 075622b7..ef52d014 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -28,6 +28,7 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -205,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query maxStreamID = m.StreamID } } - res.Devices = msgs + // remove deleted devices + var result []api.DeviceMessage + for _, m := range msgs { + if m.KeyJSON == nil { + continue + } + result = append(result, m) + } + res.Devices = result res.StreamID = maxStreamID } @@ -282,27 +291,21 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( fetchRemote := make(map[string]map[string][]string) for domain, userToDeviceMap := range domainToDeviceKeys { for userID, deviceIDs := range userToDeviceMap { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) - // if we can't query the db or there are fewer keys than requested, fetch from remote. - // Likewise, we can't safely return keys from the db when all devices are requested as we don't + // we can't safely return keys from the db when all devices are requested as we don't // know if one has just been added. - if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) { - if _, ok := fetchRemote[domain]; !ok { - fetchRemote[domain] = make(map[string][]string) + if len(deviceIDs) > 0 { + err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs) + if err == nil { + continue } - fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) - continue + util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase") } - if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - } - for _, key := range keys { - // inject the display name - key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{key.DisplayName}) - res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + // fetch device lists from remote + if _, ok := fetchRemote[domain]; !ok { + fetchRemote[domain] = make(map[string][]string) } + fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) + } } return fetchRemote @@ -324,6 +327,45 @@ func (a *KeyInternalAPI) queryRemoteKeys( defer wg.Done() fedCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() + // for users who we do not have any knowledge about, try to start doing device list updates for them + // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but + // lack a stream ID. + var userIDsForAllDevices []string + for userID, deviceIDs := range devKeys { + if len(deviceIDs) == 0 { + userIDsForAllDevices = append(userIDsForAllDevices, userID) + delete(devKeys, userID) + } + } + for _, userID := range userIDsForAllDevices { + err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + logrus.ErrorKey: err, + "user_id": userID, + "server": serverName, + }).Error("Failed to manually update device lists for user") + // try to do it via /keys/query + devKeys[userID] = []string{} + continue + } + // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this + // user so the fact that we're populating all devices here isn't a problem so long as we have devices. + err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil) + if err != nil { + logrus.WithFields(logrus.Fields{ + logrus.ErrorKey: err, + "user_id": userID, + "server": serverName, + }).Error("Failed to manually update device lists for user") + // try to do it via /keys/query + devKeys[userID] = []string{} + continue + } + } + if len(devKeys) == 0 { + return + } queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) if err != nil { failMu.Lock() @@ -357,6 +399,37 @@ func (a *KeyInternalAPI) queryRemoteKeys( } } +func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( + ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, +) error { + keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + // if we can't query the db or there are fewer keys than requested, fetch from remote. + if err != nil { + return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) + } + if len(keys) < len(deviceIDs) { + return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID) + } + if len(deviceIDs) == 0 && len(keys) == 0 { + return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) + } + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + + for _, key := range keys { + if len(key.KeyJSON) == 0 { + continue // ignore deleted keys + } + // inject the display name + key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{key.DisplayName}) + res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + } + return nil +} + func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key @@ -403,6 +476,10 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } return } + if req.OnlyDisplayNameUpdates { + // add the display name field from keysToStore into existingKeys + keysToStore = appendDisplayNames(existingKeys, keysToStore) + } // store the device keys and emit changes err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { @@ -475,3 +552,16 @@ func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) } return a.Producer.ProduceKeyChanges(keysAdded) } + +func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage { + for i, existingDevice := range existing { + for _, newDevice := range new { + if existingDevice.DeviceID != newDevice.DeviceID { + continue + } + existingDevice.DisplayName = newDevice.DisplayName + existing[i] = existingDevice + } + } + return existing +} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 2a60aacc..0ec62f56 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -43,8 +43,9 @@ type Database interface { StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error // PrevIDsExists returns true if all prev IDs exist for this user. PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index b9d5d4c3..779d02c0 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" + const countStreamIDsForUserSQL = "" + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + type deviceKeysStatements struct { db *sql.DB upsertDeviceKeysStmt *sql.Stmt @@ -68,6 +71,7 @@ type deviceKeysStatements struct { selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt countStreamIDsForUserStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { return nil, err } + if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { + return nil, err + } return s, nil } @@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx return nil } +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 4279eae7..a4c35a4b 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i return count == len(prevIDs), nil } -func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { +func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for _, userID := range clearUserIDs { + err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) + if err != nil { + return err + } + } return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) }) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index abe6636a..a4d71fe1 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" + const countStreamIDsForUserSQL = "" + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + type deviceKeysStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -65,6 +68,7 @@ type deviceKeysStatements struct { selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } + if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { + return nil, err + } return s, nil } +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index a4d5dede..f97e871f 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -38,6 +38,7 @@ type DeviceKeys interface { SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error } type KeyChanges interface { diff --git a/sytest-whitelist b/sytest-whitelist index bbac6972..d22b408a 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -146,6 +146,8 @@ If remote user leaves room we no longer receive device updates If a device list update goes missing, the server resyncs on the next one Get left notifs in sync and /keys/changes when other user leaves Can query remote device keys using POST after notification +Server correctly resyncs when client query keys and there is no remote cache +Server correctly resyncs when server leaves and rejoins a room Can add account data Can add account data to room Can get account data without syncing diff --git a/userapi/internal/api.go b/userapi/internal/api.go index f58c7113..05cecc1b 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -180,6 +180,27 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err } + if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { + // display name has changed: update the device key + var uploadRes keyapi.PerformUploadKeysResponse + a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + DeviceKeys: []keyapi.DeviceKeys{ + { + DeviceID: dev.ID, + DisplayName: *req.DisplayName, + KeyJSON: nil, + UserID: dev.UserID, + }, + }, + OnlyDisplayNameUpdates: true, + }, &uploadRes) + if uploadRes.Error != nil { + return fmt.Errorf("Failed to update device key display name: %v", uploadRes.Error) + } + if len(uploadRes.KeyErrors) > 0 { + return fmt.Errorf("Failed to update device key display name, key errors: %+v", uploadRes.KeyErrors) + } + } return nil }