Use TransactionWriter in SQLite keyserver (#1239)
* Use TransactionWriter in SQLite keyserver * Fix keyserver storage testsmain
parent
22f028e141
commit
15dc1f4d03
|
@ -20,6 +20,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
@ -54,6 +55,7 @@ const selectMaxStreamForUserSQL = "" +
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
@ -63,6 +65,7 @@ type deviceKeysStatements struct {
|
||||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
s := &deviceKeysStatements{
|
s := &deviceKeysStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err := db.Exec(deviceKeysSchema)
|
_, err := db.Exec(deviceKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -141,6 +144,7 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
|
@ -151,4 +155,5 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -51,6 +52,7 @@ const selectKeyChangesSQL = "" +
|
||||||
|
|
||||||
type keyChangesStatements struct {
|
type keyChangesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
upsertKeyChangeStmt *sql.Stmt
|
upsertKeyChangeStmt *sql.Stmt
|
||||||
selectKeyChangesStmt *sql.Stmt
|
selectKeyChangesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
@ -58,6 +60,7 @@ type keyChangesStatements struct {
|
||||||
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
s := &keyChangesStatements{
|
s := &keyChangesStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err := db.Exec(keyChangesSchema)
|
_, err := db.Exec(keyChangesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -73,8 +76,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||||
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||||
return err
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyChangesStatements) SelectKeyChanges(
|
func (s *keyChangesStatements) SelectKeyChanges(
|
||||||
|
|
|
@ -60,6 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
|
||||||
|
|
||||||
type oneTimeKeysStatements struct {
|
type oneTimeKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
upsertKeysStmt *sql.Stmt
|
upsertKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
selectKeysCountStmt *sql.Stmt
|
selectKeysCountStmt *sql.Stmt
|
||||||
|
@ -70,6 +71,7 @@ type oneTimeKeysStatements struct {
|
||||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
s := &oneTimeKeysStatements{
|
s := &oneTimeKeysStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err := db.Exec(oneTimeKeysSchema)
|
_, err := db.Exec(oneTimeKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -150,7 +152,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
|
||||||
UserID: keys.UserID,
|
UserID: keys.UserID,
|
||||||
KeyCount: make(map[string]int),
|
KeyCount: make(map[string]int),
|
||||||
}
|
}
|
||||||
return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
return counts, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
||||||
|
@ -183,14 +185,17 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
) (map[string]json.RawMessage, error) {
|
) (map[string]json.RawMessage, error) {
|
||||||
var keyID string
|
var keyID string
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
|
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
return map[string]json.RawMessage{
|
return map[string]json.RawMessage{
|
||||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
}, err
|
}, err
|
||||||
|
|
|
@ -2,6 +2,10 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -11,6 +15,21 @@ import (
|
||||||
|
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
|
||||||
|
func MustCreateDatabase(t *testing.T) (Database, func()) {
|
||||||
|
tmpfile, err := ioutil.TempFile("", "keyserver_storage_test")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Logf("Database %s", tmpfile.Name())
|
||||||
|
db, err := NewDatabase(fmt.Sprintf("file://%s", tmpfile.Name()), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||||
|
}
|
||||||
|
return db, func() {
|
||||||
|
os.Remove(tmpfile.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func MustNotError(t *testing.T, err error) {
|
func MustNotError(t *testing.T, err error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -20,10 +39,8 @@ func MustNotError(t *testing.T, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeyChanges(t *testing.T) {
|
func TestKeyChanges(t *testing.T) {
|
||||||
db, err := NewDatabase("file::memory:", nil)
|
db, clean := MustCreateDatabase(t)
|
||||||
if err != nil {
|
defer clean()
|
||||||
t.Fatalf("Failed to NewDatabase: %s", err)
|
|
||||||
}
|
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
||||||
|
@ -40,10 +57,8 @@ func TestKeyChanges(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeyChangesNoDupes(t *testing.T) {
|
func TestKeyChangesNoDupes(t *testing.T) {
|
||||||
db, err := NewDatabase("file::memory:", nil)
|
db, clean := MustCreateDatabase(t)
|
||||||
if err != nil {
|
defer clean()
|
||||||
t.Fatalf("Failed to NewDatabase: %s", err)
|
|
||||||
}
|
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
|
||||||
|
@ -60,10 +75,8 @@ func TestKeyChangesNoDupes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeyChangesUpperLimit(t *testing.T) {
|
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||||
db, err := NewDatabase("file::memory:", nil)
|
db, clean := MustCreateDatabase(t)
|
||||||
if err != nil {
|
defer clean()
|
||||||
t.Fatalf("Failed to NewDatabase: %s", err)
|
|
||||||
}
|
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
||||||
|
@ -82,10 +95,9 @@ func TestKeyChangesUpperLimit(t *testing.T) {
|
||||||
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
||||||
// and that they are returned correctly when querying for device keys.
|
// and that they are returned correctly when querying for device keys.
|
||||||
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
db, err := NewDatabase("file::memory:", nil)
|
var err error
|
||||||
if err != nil {
|
db, clean := MustCreateDatabase(t)
|
||||||
t.Fatalf("Failed to NewDatabase: %s", err)
|
defer clean()
|
||||||
}
|
|
||||||
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||||
msgs := []api.DeviceMessage{
|
msgs := []api.DeviceMessage{
|
||||||
|
|
Loading…
Reference in New Issue