From adf7b5929401f56bedba92ef778b5e56feefc479 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Tue, 28 Jul 2020 17:38:30 +0100 Subject: [PATCH] Persist partition|offset|user_id in the keyserver (#1226) * Persist partition|offset|user_id in the keyserver Required for a query API, used by the syncapi when a `/sync` request comes in, which returns a list of user IDs of people who have changed their device keys between two tokens. * Add tests and fix maxOffset bug * s/offset/log_offset/g because 'offset' is a reserved word in postgres --- keyserver/keyserver.go | 1 + keyserver/producers/keychange.go | 7 ++ keyserver/storage/interface.go | 8 ++ .../storage/postgres/key_changes_table.go | 97 ++++++++++++++++++ keyserver/storage/postgres/storage.go | 5 + keyserver/storage/shared/storage.go | 9 ++ .../storage/sqlite3/key_changes_table.go | 98 +++++++++++++++++++ keyserver/storage/sqlite3/storage.go | 5 + keyserver/storage/storage_test.go | 57 +++++++++++ keyserver/storage/tables/interface.go | 5 + 10 files changed, 292 insertions(+) create mode 100644 keyserver/storage/postgres/key_changes_table.go create mode 100644 keyserver/storage/sqlite3/key_changes_table.go create mode 100644 keyserver/storage/storage_test.go diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 47c6a8c3..c748d7ce 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -49,6 +49,7 @@ func NewInternalAPI( keyChangeProducer := &producers.KeyChange{ Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent), Producer: producer, + DB: db, } return &internal.KeyInternalAPI{ DB: db, diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index 6683a936..d59dd200 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -15,10 +15,12 @@ package producers import ( + "context" "encoding/json" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage" "github.com/sirupsen/logrus" ) @@ -26,6 +28,7 @@ import ( type KeyChange struct { Topic string Producer sarama.SyncProducer + DB storage.Database } // ProduceKeyChanges creates new change events for each key @@ -46,6 +49,10 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { if err != nil { return err } + err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID) + if err != nil { + return err + } logrus.WithFields(logrus.Fields{ "user_id": key.UserID, "device_id": key.DeviceID, diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 7a0328bd..f4787790 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -43,4 +43,12 @@ type Database interface { // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) + + // StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the user who has changed + // their keys in some way. + StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error + + // KeyChanges returns a list of user IDs who have modified their keys from the offset given. + // Returns the offset of the latest key change.
+ KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) } diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go new file mode 100644 index 00000000..9d259f9f --- /dev/null +++ b/keyserver/storage/postgres/key_changes_table.go @@ -0,0 +1,97 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + log_offset BIGINT NOT NULL, + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) +); +` + +// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. +// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will +// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. +const upsertKeyChangeSQL = "" + + "INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" + + " DO UPDATE SET user_id = $3" + +// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just +// take the max offset value as the latest offset. 
+const selectKeyChangesSQL = "" + + "SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 GROUP BY user_id" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return nil, err + } + if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { + return nil, err + } + if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, partition int32, fromOffset int64, +) (userIDs []string, latestOffset int64, err error) { + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index 4f3217b6..a1d1c0fe 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -34,9 +34,14 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s if err != nil { return nil, err } + kc, err := NewPostgresKeyChangesTable(db) + if err != nil { + return nil, err + } return &shared.Database{ DB: db, OneTimeKeysTable: otk, DeviceKeysTable: dk, + KeyChangesTable: kc, }, nil } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 156b5b41..537a5f7b 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -28,6 +28,7 @@ type Database struct { DB *sql.DB OneTimeKeysTable tables.OneTimeKeys DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges } func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { @@ -72,3 +73,11 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st }) return result, err } + +func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { + return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID) +} + +func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset) +} diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go new file mode 100644 index 00000000..b830214d --- /dev/null +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -0,0 +1,98 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + offset BIGINT NOT NULL, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (partition, offset) +); +` + +// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. +// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will +// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. +const upsertKeyChangeSQL = "" + + "INSERT INTO keyserver_key_changes (partition, offset, user_id)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT (partition, offset)" + + " DO UPDATE SET user_id = $3" + +// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just +// take the max offset value as the latest offset. +const selectKeyChangesSQL = "" + + "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 GROUP BY user_id" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return nil, err + } + if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { + return nil, err + } + if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, partition int32, fromOffset int64, +) (userIDs []string, latestOffset int64, err error) { + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index f3566ef5..f9771cf1 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -37,9 +37,14 @@ func NewDatabase(dataSourceName 
string) (*shared.Database, error) { if err != nil { return nil, err } + kc, err := NewSqliteKeyChangesTable(db) + if err != nil { + return nil, err + } return &shared.Database{ DB: db, OneTimeKeysTable: otk, DeviceKeysTable: dk, + KeyChangesTable: kc, }, nil } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go new file mode 100644 index 00000000..88972478 --- /dev/null +++ b/keyserver/storage/storage_test.go @@ -0,0 +1,57 @@ +package storage + +import ( + "context" + "reflect" + "testing" +) + +var ctx = context.Background() + +func MustNotError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("operation failed: %s", err) +} + +func TestKeyChanges(t *testing.T) { + db, err := NewDatabase("file::memory:", nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) + MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) + MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) + userIDs, latest, err := db.KeyChanges(ctx, 0, 1) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != 2 { + t.Fatalf("KeyChanges: got latest=%d want 2", latest) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } +} + +func TestKeyChangesNoDupes(t *testing.T) { + db, err := NewDatabase("file::memory:", nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) + MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost")) + MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost")) + userIDs, latest, err := db.KeyChanges(ctx, 0, 0) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != 2 { + t.Fatalf("KeyChanges: got latest=%d want 2", latest) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 216be773..824b9f0f 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -35,3 +35,8 @@ type DeviceKeys interface { InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) } + +type KeyChanges interface { + InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error + SelectKeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) +}
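
As context for reviewers: the syncapi is the intended consumer of the new KeyChanges query, calling it while servicing `/sync` to work out whose device lists changed between two tokens. The sketch below illustrates that call pattern only; the DeviceListPosition type and changedUserIDs helper are hypothetical and not part of this change, and they assume a sync token carries the Kafka partition and offset of the last key change the client saw.

package sketch

import (
	"context"

	"github.com/matrix-org/dendrite/keyserver/storage"
)

// DeviceListPosition is a hypothetical stand-in for the part of a sync token
// that tracks key changes: the Kafka partition being read and the offset of
// the last change the client has already seen.
type DeviceListPosition struct {
	Partition int32
	Offset    int64
}

// changedUserIDs asks the keyserver database for everyone who changed their
// device keys after `since` and returns the advanced position to embed in the
// next sync token.
func changedUserIDs(ctx context.Context, db storage.Database, since DeviceListPosition) ([]string, DeviceListPosition, error) {
	userIDs, latestOffset, err := db.KeyChanges(ctx, since.Partition, since.Offset)
	if err != nil {
		return nil, since, err
	}
	if len(userIDs) == 0 {
		// Nothing changed: keep the old position rather than rewinding to the
		// zero-valued latestOffset.
		return nil, since, nil
	}
	return userIDs, DeviceListPosition{Partition: since.Partition, Offset: latestOffset}, nil
}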
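
For completeness, the keychange.go hunk above only shows the lines added after the Kafka send succeeds. The condensed sketch below, assuming the sarama SyncProducer API, shows where partition and offset come from: SendMessage reports the partition and offset the change event was written to, and those are exactly the values handed to StoreKeyChange so the query API can later answer "who changed keys after offset X". The message construction here is illustrative rather than copied from the file.

package sketch

import (
	"context"
	"encoding/json"

	"github.com/Shopify/sarama"

	"github.com/matrix-org/dendrite/keyserver/api"
	"github.com/matrix-org/dendrite/keyserver/storage"
)

// produceKeyChange publishes one device key change and then persists the
// (partition, offset, user_id) triple it was assigned, mirroring the shape of
// the ProduceKeyChanges loop touched by this patch.
func produceKeyChange(p sarama.SyncProducer, db storage.Database, topic string, key api.DeviceKeys) error {
	value, err := json.Marshal(key)
	if err != nil {
		return err
	}
	// SendMessage returns the partition and offset the message landed at.
	partition, offset, err := p.SendMessage(&sarama.ProducerMessage{
		Topic: topic,
		Key:   sarama.StringEncoder(key.UserID),
		Value: sarama.ByteEncoder(value),
	})
	if err != nil {
		return err
	}
	// These are the same offsets that TestKeyChanges reads back via KeyChanges.
	return db.StoreKeyChange(context.Background(), partition, offset, key.UserID)
}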