diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 19d8463d..ec7dff56 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -23,7 +23,6 @@ import ( "time" "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -65,7 +64,7 @@ type DeviceListUpdater struct { mu *sync.Mutex // protects UserIDToMutex db DeviceListUpdaterDatabase - producer *producers.KeyChange + producer KeyChangeProducer fedClient *gomatrixserverlib.FederationClient workerChans []chan gomatrixserverlib.ServerName } @@ -88,9 +87,14 @@ type DeviceListUpdaterDatabase interface { PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) } +// KeyChangeProducer is the interface for producers.KeyChange useful for testing. +type KeyChangeProducer interface { + ProduceKeyChanges(keys []api.DeviceMessage) error +} + // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. func NewDeviceListUpdater( - db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient, + db DeviceListUpdaterDatabase, producer KeyChangeProducer, fedClient *gomatrixserverlib.FederationClient, numWorkers int, ) *DeviceListUpdater { return &DeviceListUpdater{ @@ -154,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. if err != nil { return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) } + // if this is the first time we're hearing about this user, sync the device list manually. + if len(event.PrevID) == 0 { + exists = false + } util.GetLogger(ctx).WithFields(logrus.Fields{ "prev_ids_exist": exists, "user_id": event.UserID, "device_id": event.DeviceID, "stream_id": event.StreamID, "prev_ids": event.PrevID, + "display_name": event.DeviceDisplayName, }).Info("DeviceListUpdater.Update") // if we haven't missed anything update the database and notify users @@ -263,16 +272,17 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam hasFailures = true continue } - err = u.updateDeviceList(ctx, &res) + err = u.updateDeviceList(&res) if err != nil { - logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it") + logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it") hasFailures = true } } return hasFailures } -func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error { +func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { + ctx := context.Background() // we've got the keys, don't time out when persisting them to the database. keys := make([]api.DeviceMessage, len(res.Devices)) for i, device := range res.Devices { keyJSON, err := json.Marshal(device.Keys) @@ -292,7 +302,15 @@ func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixs } err := u.db.StoreRemoteDeviceKeys(ctx, keys) if err != nil { - return err + return fmt.Errorf("failed to store remote device keys: %w", err) } - return u.db.MarkDeviceListStale(ctx, res.UserID, false) + err = u.db.MarkDeviceListStale(ctx, res.UserID, false) + if err != nil { + return fmt.Errorf("failed to mark device list as fresh: %w", err) + } + err = u.producer.ProduceKeyChanges(keys) + if err != nil { + return fmt.Errorf("failed to emit key changes for fresh device list: %w", err) + } + return nil } diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go new file mode 100644 index 00000000..50e42763 --- /dev/null +++ b/keyserver/internal/device_list_update_test.go @@ -0,0 +1,242 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "crypto/ed25519" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +var ( + ctx = context.Background() +) + +type mockKeyChangeProducer struct { + events []api.DeviceMessage +} + +func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error { + p.events = append(p.events, keys...) + return nil +} + +type mockDeviceListUpdaterDatabase struct { + staleUsers map[string]bool + prevIDsExist func(string, []int) bool + storedKeys []api.DeviceMessage +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + var result []string + for userID := range d.staleUsers { + _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return nil, err + } + if len(domains) == 0 { + result = append(result, userID) + continue + } + for _, d := range domains { + if remoteServer == d { + result = append(result, userID) + break + } + } + } + return result, nil +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + d.staleUsers[userID] = isStale + return nil +} + +// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key +// for this (user, device). Does not modify the stream ID for keys. +func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + d.storedKeys = append(d.storedKeys, keys...) + return nil +} + +// PrevIDsExists returns true if all prev IDs exist for this user. +func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { + return d.prevIDsExist(userID, prevIDs), nil +} + +type roundTripper struct { + fn func(*http.Request) (*http.Response, error) +} + +func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.fn(req) +} + +func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { + _, pkey, _ := ed25519.GenerateKey(nil) + fedClient := gomatrixserverlib.NewFederationClient( + gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, + ) + fedClient.Client = *gomatrixserverlib.NewClientWithTransport(&roundTripper{tripper}) + return fedClient +} + +// Test that the device keys get persisted and emitted if we have the previous IDs. +func TestUpdateHavePrevID(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int) bool { + return true + }, + } + producer := &mockKeyChangeProducer{} + updater := NewDeviceListUpdater(db, producer, nil, 1) + event := gomatrixserverlib.DeviceListUpdateEvent{ + DeviceDisplayName: "Foo Bar", + Deleted: false, + DeviceID: "FOO", + Keys: []byte(`{"key":"value"}`), + PrevID: []int{0}, + StreamID: 1, + UserID: "@alice:localhost", + } + err := updater.Update(ctx, event) + if err != nil { + t.Fatalf("Update returned an error: %s", err) + } + want := api.DeviceMessage{ + StreamID: event.StreamID, + DeviceKeys: api.DeviceKeys{ + DeviceID: event.DeviceID, + DisplayName: event.DeviceDisplayName, + KeyJSON: event.Keys, + UserID: event.UserID, + }, + } + if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { + t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) + } + if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { + t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + } + if db.staleUsers[event.UserID] { + t.Errorf("%s incorrectly marked as stale", event.UserID) + } +} + +// Test that device keys are fetched from the remote server if we are missing prev IDs +// and that the user's devices are marked as stale until it succeeds. +func TestUpdateNoPrevID(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int) bool { + return false + }, + } + producer := &mockKeyChangeProducer{} + remoteUserID := "@alice:example.somewhere" + var wg sync.WaitGroup + wg.Add(1) + keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + defer wg.Done() + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(` + { + "user_id": "` + remoteUserID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + }, nil + }) + updater := NewDeviceListUpdater(db, producer, fedClient, 2) + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + event := gomatrixserverlib.DeviceListUpdateEvent{ + DeviceDisplayName: "Mobile Phone", + Deleted: false, + DeviceID: "another_device_id", + Keys: []byte(`{"key":"value"}`), + PrevID: []int{3}, + StreamID: 4, + UserID: remoteUserID, + } + err := updater.Update(ctx, event) + if err != nil { + t.Fatalf("Update returned an error: %s", err) + } + // At this point we show have this device list marked as stale and not store the keys or emitted anything + if !db.staleUsers[event.UserID] { + t.Errorf("%s not marked as stale", event.UserID) + } + if len(producer.events) > 0 { + t.Errorf("Update incorrect emitted %d device change events", len(producer.events)) + } + if len(db.storedKeys) > 0 { + t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys)) + } + t.Log("waiting for /users/devices to be called...") + wg.Wait() + // wait a bit for db to be updated... + time.Sleep(100 * time.Millisecond) + want := api.DeviceMessage{ + StreamID: 5, + DeviceKeys: api.DeviceKeys{ + DeviceID: "JLAFKJWSCS", + DisplayName: "Mobile Phone", + UserID: remoteUserID, + KeyJSON: []byte(keyJSON), + }, + } + // Now we should have a fresh list and the keys and emitted something + if db.staleUsers[event.UserID] { + t.Errorf("%s still marked as stale", event.UserID) + } + if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { + t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON)) + t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) + } + if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { + t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + } + +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index ff298c07..075622b7 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques if len(dk.KeyJSON) == 0 { continue // don't include blank keys } - // inject display name if known + // inject display name if known (either locally or remotely) + displayName := dk.DisplayName + if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { + displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName + } dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` - }{queryRes.DeviceInfo[dk.DeviceID].DisplayName}) + }{displayName}) res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON } } else { @@ -261,12 +265,49 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) } } - // TODO: set device display names when they are known + + // attempt to satisfy key queries from the local database first as we should get device updates pushed to us + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) + if len(domainToDeviceKeys) == 0 { + return // nothing to query + } // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) } +func (a *KeyInternalAPI) remoteKeysFromDatabase( + ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, +) map[string]map[string][]string { + fetchRemote := make(map[string]map[string][]string) + for domain, userToDeviceMap := range domainToDeviceKeys { + for userID, deviceIDs := range userToDeviceMap { + keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + // if we can't query the db or there are fewer keys than requested, fetch from remote. + // Likewise, we can't safely return keys from the db when all devices are requested as we don't + // know if one has just been added. + if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) { + if _, ok := fetchRemote[domain]; !ok { + fetchRemote[domain] = make(map[string][]string) + } + fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) + continue + } + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + for _, key := range keys { + // inject the display name + key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{key.DisplayName}) + res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + } + } + } + return fetchRemote +} + func (a *KeyInternalAPI) queryRemoteKeys( ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, ) { diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index d321860d..b9d5d4c3 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( -- required in the spec because in the event of a missed update the server fetches the entire -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. stream_id BIGINT NOT NULL, + display_name TEXT, -- Clobber based on tuple of user/device. CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4, stream_id = $5" + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] for i, key := range keys { var keyJSONStr string var streamID int - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } } return nil } @@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { return err @@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.UserID = userID var keyJSON string var streamID int - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go new file mode 100644 index 00000000..63281adf --- /dev/null +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index a1d1c0fe..de2fabfd 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -38,10 +38,15 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s if err != nil { return nil, err } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } return &shared.Database{ - DB: db, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, }, nil } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 68964be6..4279eae7 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -26,10 +26,11 @@ import ( ) type Database struct { - DB *sql.DB - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges + DB *sql.DB + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists } func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { @@ -129,10 +130,10 @@ func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return nil, nil // TODO + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) } // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return nil // TODO + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 15d9c775..abe6636a 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, stream_id BIGINT NOT NULL, + display_name TEXT, -- Clobber based on tuple of user/device. UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4, stream_id = $5" + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.UserID = userID var keyJSON string var streamID int - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) @@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] for i, key := range keys { var keyJSONStr string var streamID int - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } } return nil } @@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { return err diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index f910479f..907966a7 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -196,6 +196,9 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) return err }) + if keyJSON == "" { + return nil, nil + } return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go new file mode 100644 index 00000000..a989476d --- /dev/null +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index f9771cf1..bbfd1e79 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -41,10 +41,15 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) { if err != nil { return nil, err } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } return &shared.Database{ - DB: db, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, }, nil } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index ac932d56..a4d5dede 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/gomatrixserverlib" ) type OneTimeKeys interface { @@ -45,3 +46,8 @@ type KeyChanges interface { // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) } + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) +} diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 66134d79..e0379aaf 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -46,6 +46,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. +// nolint:gocyclo func DeviceListCatchup( ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, res *types.Response, from, to types.StreamingToken, @@ -68,22 +69,20 @@ func DeviceListCatchup( var partition int32 var offset int64 + partition = -1 + offset = sarama.OffsetOldest // Extract partition/offset from sync token // TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make. logOffset := from.Log(DeviceListLogName) if logOffset != nil { partition = logOffset.Partition offset = logOffset.Offset - } else { - partition = -1 - offset = sarama.OffsetOldest } var toOffset int64 + toOffset = sarama.OffsetNewest toLog := to.Log(DeviceListLogName) - if toLog != nil { + if toLog != nil && toLog.Offset > 0 { toOffset = toLog.Offset - } else { - toOffset = sarama.OffsetNewest } var queryRes api.QueryKeyChangesResponse keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ @@ -96,6 +95,10 @@ func DeviceListCatchup( util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") return hasNew, nil } + util.GetLogger(ctx).Debugf( + "QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v", + partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs, + ) userSet := make(map[string]bool) for _, userID := range res.DeviceLists.Changed { userSet[userID] = true @@ -116,6 +119,13 @@ func DeviceListCatchup( userSet[userID] = true } } + // set the new token + to.SetLog(DeviceListLogName, &types.LogPosition{ + Partition: queryRes.Partition, + Offset: queryRes.Offset, + }) + res.NextBatch = to.String() + return hasNew, nil } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index f465d9ff..f3324800 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -112,6 +112,9 @@ type StreamingToken struct { } func (t *StreamingToken) SetLog(name string, lp *LogPosition) { + if t.logs == nil { + t.logs = make(map[string]*LogPosition) + } t.logs[name] = lp } @@ -173,12 +176,14 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) } ret.Positions[i] = other.Positions[i] } + ret.logs = make(map[string]*LogPosition) for name := range t.logs { otherLog := other.Log(name) if otherLog == nil { continue } - t.logs[name] = otherLog + copy := *otherLog + ret.logs[name] = © } return ret } diff --git a/sytest-whitelist b/sytest-whitelist index 18978bbe..cc49bf38 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -138,6 +138,7 @@ Users receive device_list updates for their own devices Get left notifs for other users in sync and /keys/changes when user leaves Local device key changes get to remote servers Local device key changes get to remote servers with correct prev_id +#Server correctly handles incoming m.device_list_update Can add account data Can add account data to room Can get account data without syncing