Fix more E2E sytests (#1265)
* WIP: Eagerly sync device lists on /user/keys/query requests Also notify servers when a user's device display name changes. Few caveats: - sytest `Device deletion propagates over federation` fails - `populateResponseWithDeviceKeysFromDatabase` is called from multiple goroutines and hence is unsafe. * Handle deleted devices correctly over federation
This commit is contained in:
parent
d98ec12422
commit
820c56c165
11 changed files with 197 additions and 27 deletions
|
@ -110,6 +110,11 @@ type OneTimeKeysCount struct {
|
|||
type PerformUploadKeysRequest struct {
|
||||
DeviceKeys []DeviceKeys
|
||||
OneTimeKeys []OneTimeKeys
|
||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||
// the display name for their respective device, and NOT to modify the keys. The key
|
||||
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||
// Without this flag, requests to modify device display names would delete device keys.
|
||||
OnlyDisplayNameUpdates bool
|
||||
}
|
||||
|
||||
// PerformUploadKeysResponse is the response to PerformUploadKeys
|
||||
|
|
|
@ -85,8 +85,9 @@ type DeviceListUpdaterDatabase interface {
|
|||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||
|
@ -144,6 +145,20 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
|||
return u.userIDToMutex[userID]
|
||||
}
|
||||
|
||||
// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
|
||||
// Blocks until the device list is synced or the timeout is reached.
|
||||
func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
|
||||
mu := u.mutex(userID)
|
||||
mu.Lock()
|
||||
err := u.db.MarkDeviceListStale(ctx, userID, true)
|
||||
mu.Unlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
|
||||
}
|
||||
u.notifyWorkers(userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
|
||||
// which assumes when /send 200 OKs that the device lists have been updated.
|
||||
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
|
||||
|
@ -178,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
|
|||
"stream_id": event.StreamID,
|
||||
"prev_ids": event.PrevID,
|
||||
"display_name": event.DeviceDisplayName,
|
||||
"deleted": event.Deleted,
|
||||
}).Info("DeviceListUpdater.Update")
|
||||
|
||||
// if we haven't missed anything update the database and notify users
|
||||
if exists {
|
||||
k := event.Keys
|
||||
if event.Deleted {
|
||||
k = nil
|
||||
}
|
||||
keys := []api.DeviceMessage{
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: event.Keys,
|
||||
KeyJSON: k,
|
||||
UserID: event.UserID,
|
||||
},
|
||||
StreamID: event.StreamID,
|
||||
},
|
||||
}
|
||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys)
|
||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
|
@ -348,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
|
|||
},
|
||||
}
|
||||
}
|
||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
|
||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store remote device keys: %w", err)
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context,
|
|||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
|
||||
d.storedKeys = append(d.storedKeys, keys...)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
@ -205,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
|||
maxStreamID = m.StreamID
|
||||
}
|
||||
}
|
||||
res.Devices = msgs
|
||||
// remove deleted devices
|
||||
var result []api.DeviceMessage
|
||||
for _, m := range msgs {
|
||||
if m.KeyJSON == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, m)
|
||||
}
|
||||
res.Devices = result
|
||||
res.StreamID = maxStreamID
|
||||
}
|
||||
|
||||
|
@ -282,27 +291,21 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
|||
fetchRemote := make(map[string]map[string][]string)
|
||||
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||
for userID, deviceIDs := range userToDeviceMap {
|
||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
// Likewise, we can't safely return keys from the db when all devices are requested as we don't
|
||||
// we can't safely return keys from the db when all devices are requested as we don't
|
||||
// know if one has just been added.
|
||||
if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
|
||||
if _, ok := fetchRemote[domain]; !ok {
|
||||
fetchRemote[domain] = make(map[string][]string)
|
||||
if len(deviceIDs) > 0 {
|
||||
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
||||
continue
|
||||
util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
|
||||
}
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
for _, key := range keys {
|
||||
// inject the display name
|
||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{key.DisplayName})
|
||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||
// fetch device lists from remote
|
||||
if _, ok := fetchRemote[domain]; !ok {
|
||||
fetchRemote[domain] = make(map[string][]string)
|
||||
}
|
||||
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
||||
|
||||
}
|
||||
}
|
||||
return fetchRemote
|
||||
|
@ -324,6 +327,45 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
|||
defer wg.Done()
|
||||
fedCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
// for users who we do not have any knowledge about, try to start doing device list updates for them
|
||||
// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
|
||||
// lack a stream ID.
|
||||
var userIDsForAllDevices []string
|
||||
for userID, deviceIDs := range devKeys {
|
||||
if len(deviceIDs) == 0 {
|
||||
userIDsForAllDevices = append(userIDsForAllDevices, userID)
|
||||
delete(devKeys, userID)
|
||||
}
|
||||
}
|
||||
for _, userID := range userIDsForAllDevices {
|
||||
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
"server": serverName,
|
||||
}).Error("Failed to manually update device lists for user")
|
||||
// try to do it via /keys/query
|
||||
devKeys[userID] = []string{}
|
||||
continue
|
||||
}
|
||||
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
||||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
"server": serverName,
|
||||
}).Error("Failed to manually update device lists for user")
|
||||
// try to do it via /keys/query
|
||||
devKeys[userID] = []string{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(devKeys) == 0 {
|
||||
return
|
||||
}
|
||||
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
|
||||
if err != nil {
|
||||
failMu.Lock()
|
||||
|
@ -357,6 +399,37 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
|||
}
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
||||
) error {
|
||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
||||
}
|
||||
if len(keys) < len(deviceIDs) {
|
||||
return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
|
||||
}
|
||||
if len(deviceIDs) == 0 && len(keys) == 0 {
|
||||
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||
}
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if len(key.KeyJSON) == 0 {
|
||||
continue // ignore deleted keys
|
||||
}
|
||||
// inject the display name
|
||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{key.DisplayName})
|
||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
var keysToStore []api.DeviceMessage
|
||||
// assert that the user ID / device ID are not lying for each key
|
||||
|
@ -403,6 +476,10 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||
}
|
||||
return
|
||||
}
|
||||
if req.OnlyDisplayNameUpdates {
|
||||
// add the display name field from keysToStore into existingKeys
|
||||
keysToStore = appendDisplayNames(existingKeys, keysToStore)
|
||||
}
|
||||
// store the device keys and emit changes
|
||||
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
if err != nil {
|
||||
|
@ -475,3 +552,16 @@ func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage)
|
|||
}
|
||||
return a.Producer.ProduceKeyChanges(keysAdded)
|
||||
}
|
||||
|
||||
func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {
|
||||
for i, existingDevice := range existing {
|
||||
for _, newDevice := range new {
|
||||
if existingDevice.DeviceID != newDevice.DeviceID {
|
||||
continue
|
||||
}
|
||||
existingDevice.DisplayName = newDevice.DisplayName
|
||||
existing[i] = existingDevice
|
||||
}
|
||||
}
|
||||
return existing
|
||||
}
|
||||
|
|
|
@ -43,8 +43,9 @@ type Database interface {
|
|||
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||
|
|
|
@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" +
|
|||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||
|
||||
const deleteAllDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
|
@ -68,6 +71,7 @@ type deviceKeysStatements struct {
|
|||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
|
@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
|
|
|
@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
|
|||
return count == len(prevIDs), nil
|
||||
}
|
||||
|
||||
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
|
||||
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||
for _, userID := range clearUserIDs {
|
||||
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
|
|||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||
|
||||
const deleteAllDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
|
@ -65,6 +68,7 @@ type deviceKeysStatements struct {
|
|||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
|
@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
deviceIDMap := make(map[string]bool)
|
||||
for _, d := range deviceIDs {
|
||||
|
|
|
@ -38,6 +38,7 @@ type DeviceKeys interface {
|
|||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
||||
type KeyChanges interface {
|
||||
|
|
|
@ -146,6 +146,8 @@ If remote user leaves room we no longer receive device updates
|
|||
If a device list update goes missing, the server resyncs on the next one
|
||||
Get left notifs in sync and /keys/changes when other user leaves
|
||||
Can query remote device keys using POST after notification
|
||||
Server correctly resyncs when client query keys and there is no remote cache
|
||||
Server correctly resyncs when server leaves and rejoins a room
|
||||
Can add account data
|
||||
Can add account data to room
|
||||
Can get account data without syncing
|
||||
|
|
|
@ -180,6 +180,27 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
|||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||
return err
|
||||
}
|
||||
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
|
||||
// display name has changed: update the device key
|
||||
var uploadRes keyapi.PerformUploadKeysResponse
|
||||
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
||||
DeviceKeys: []keyapi.DeviceKeys{
|
||||
{
|
||||
DeviceID: dev.ID,
|
||||
DisplayName: *req.DisplayName,
|
||||
KeyJSON: nil,
|
||||
UserID: dev.UserID,
|
||||
},
|
||||
},
|
||||
OnlyDisplayNameUpdates: true,
|
||||
}, &uploadRes)
|
||||
if uploadRes.Error != nil {
|
||||
return fmt.Errorf("Failed to update device key display name: %v", uploadRes.Error)
|
||||
}
|
||||
if len(uploadRes.KeyErrors) > 0 {
|
||||
return fmt.Errorf("Failed to update device key display name, key errors: %+v", uploadRes.KeyErrors)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue