Persist partition|offset|user_id in the keyserver (#1226)
* Persist partition|offset|user_id in the keyserver Required for a query API which will be used by the syncapi which will be called when a `/sync` request comes in which will return a list of user IDs of people who have changed their device keys between two tokens. * Add tests and fix maxOffset bug * s/offset/log_offset/g because 'offset' is a reserved word in postgresmain
parent
acc8e80a51
commit
adf7b59294
|
@ -49,6 +49,7 @@ func NewInternalAPI(
|
||||||
keyChangeProducer := &producers.KeyChange{
|
keyChangeProducer := &producers.KeyChange{
|
||||||
Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent),
|
Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent),
|
||||||
Producer: producer,
|
Producer: producer,
|
||||||
|
DB: db,
|
||||||
}
|
}
|
||||||
return &internal.KeyInternalAPI{
|
return &internal.KeyInternalAPI{
|
||||||
DB: db,
|
DB: db,
|
||||||
|
|
|
@ -15,10 +15,12 @@
|
||||||
package producers
|
package producers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,6 +28,7 @@ import (
|
||||||
type KeyChange struct {
|
type KeyChange struct {
|
||||||
Topic string
|
Topic string
|
||||||
Producer sarama.SyncProducer
|
Producer sarama.SyncProducer
|
||||||
|
DB storage.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProduceKeyChanges creates new change events for each key
|
// ProduceKeyChanges creates new change events for each key
|
||||||
|
@ -46,6 +49,10 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"user_id": key.UserID,
|
"user_id": key.UserID,
|
||||||
"device_id": key.DeviceID,
|
"device_id": key.DeviceID,
|
||||||
|
|
|
@ -43,4 +43,12 @@ type Database interface {
|
||||||
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
||||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||||
|
|
||||||
|
// StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
|
||||||
|
// their keys in some way.
|
||||||
|
StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
||||||
|
|
||||||
|
// KeyChanges returns a list of user IDs who have modified their keys from the offset given.
|
||||||
|
// Returns the offset of the latest key change.
|
||||||
|
KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var keyChangesSchema = `
|
||||||
|
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
|
partition BIGINT NOT NULL,
|
||||||
|
log_offset BIGINT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||||
|
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||||
|
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||||
|
const upsertKeyChangeSQL = "" +
|
||||||
|
"INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
|
||||||
|
" VALUES ($1, $2, $3)" +
|
||||||
|
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
|
||||||
|
" DO UPDATE SET user_id = $3"
|
||||||
|
|
||||||
|
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||||
|
// take the max offset value as the latest offset.
|
||||||
|
const selectKeyChangesSQL = "" +
|
||||||
|
"SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 GROUP BY user_id"
|
||||||
|
|
||||||
|
type keyChangesStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertKeyChangeStmt *sql.Stmt
|
||||||
|
selectKeyChangesStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
|
s := &keyChangesStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(keyChangesSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||||
|
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) SelectKeyChanges(
|
||||||
|
ctx context.Context, partition int32, fromOffset int64,
|
||||||
|
) (userIDs []string, latestOffset int64, err error) {
|
||||||
|
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
var offset int64
|
||||||
|
if err := rows.Scan(&userID, &offset); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if offset > latestOffset {
|
||||||
|
latestOffset = offset
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -34,9 +34,14 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
kc, err := NewPostgresKeyChangesTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
|
KeyChangesTable: kc,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
OneTimeKeysTable tables.OneTimeKeys
|
OneTimeKeysTable tables.OneTimeKeys
|
||||||
DeviceKeysTable tables.DeviceKeys
|
DeviceKeysTable tables.DeviceKeys
|
||||||
|
KeyChangesTable tables.KeyChanges
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
@ -72,3 +73,11 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
|
||||||
})
|
})
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||||
|
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||||
|
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var keyChangesSchema = `
|
||||||
|
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
|
partition BIGINT NOT NULL,
|
||||||
|
offset BIGINT NOT NULL,
|
||||||
|
-- The key owner
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
UNIQUE (partition, offset)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||||
|
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||||
|
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||||
|
const upsertKeyChangeSQL = "" +
|
||||||
|
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||||
|
" VALUES ($1, $2, $3)" +
|
||||||
|
" ON CONFLICT (partition, offset)" +
|
||||||
|
" DO UPDATE SET user_id = $3"
|
||||||
|
|
||||||
|
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||||
|
// take the max offset value as the latest offset.
|
||||||
|
const selectKeyChangesSQL = "" +
|
||||||
|
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 GROUP BY user_id"
|
||||||
|
|
||||||
|
type keyChangesStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertKeyChangeStmt *sql.Stmt
|
||||||
|
selectKeyChangesStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
|
s := &keyChangesStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(keyChangesSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||||
|
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) SelectKeyChanges(
|
||||||
|
ctx context.Context, partition int32, fromOffset int64,
|
||||||
|
) (userIDs []string, latestOffset int64, err error) {
|
||||||
|
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
var offset int64
|
||||||
|
if err := rows.Scan(&userID, &offset); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if offset > latestOffset {
|
||||||
|
latestOffset = offset
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -37,9 +37,14 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
kc, err := NewSqliteKeyChangesTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
|
KeyChangesTable: kc,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ctx = context.Background()
|
||||||
|
|
||||||
|
func MustNotError(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("operation failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyChanges(t *testing.T) {
|
||||||
|
db, err := NewDatabase("file::memory:", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
||||||
|
userIDs, latest, err := db.KeyChanges(ctx, 0, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
|
}
|
||||||
|
if latest != 2 {
|
||||||
|
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||||
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyChangesNoDupes(t *testing.T) {
|
||||||
|
db, err := NewDatabase("file::memory:", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
|
||||||
|
userIDs, latest, err := db.KeyChanges(ctx, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
|
}
|
||||||
|
if latest != 2 {
|
||||||
|
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||||
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
}
|
||||||
|
}
|
|
@ -35,3 +35,8 @@ type DeviceKeys interface {
|
||||||
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
|
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type KeyChanges interface {
|
||||||
|
InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
||||||
|
SelectKeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue