Update all usages of tx.Stmt to sqlutil.TxStmt (#1423)

* Replace all usages of txn.Stmt with sqlutil.TxStmt

Signed-off-by: Sam Day <me@samcday.com>

* Fix sign off link in PR template.

Signed-off-by: Sam Day <me@samcday.com>

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
main
Sam 2020-09-24 12:10:14 +02:00 committed by GitHub
parent 60524f4b99
commit a6700331ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 55 additions and 36 deletions

View File

@ -3,4 +3,4 @@
<!-- Please read CONTRIBUTING.md before submitting your pull request --> <!-- Please read CONTRIBUTING.md before submitting your pull request -->
* [ ] I have added any new tests that need to pass to `testfile` as specified in [docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md) * [ ] I have added any new tests that need to pass to `testfile` as specified in [docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md)
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/CONTRIBUTING.md#sign-off) * [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/docs/CONTRIBUTING.md#sign-off)

View File

@ -88,6 +88,14 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
return statement return statement
} }
// TxStmtContext behaves similarly to TxStmt, with support for also passing context.
func TxStmtContext(context context.Context, transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
if transaction != nil {
statement = transaction.StmtContext(context, statement)
}
return statement
}
// Hack of the century // Hack of the century
func QueryVariadic(count int) string { func QueryVariadic(count int) string {
return QueryVariadicOffset(count, 0) return QueryVariadicOffset(count, 0)

View File

@ -21,6 +21,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -125,7 +126,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt32
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
@ -151,7 +152,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
) )
if err != nil { if err != nil {
@ -162,7 +163,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
} }
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err return err
} }

View File

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -151,14 +152,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.
} }
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -180,14 +181,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{ return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON), algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err }, err

View File

@ -97,7 +97,7 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
} }
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err return err
} }
@ -156,7 +156,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt32
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
@ -188,7 +188,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
) )
if err != nil { if err != nil {

View File

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -153,14 +154,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
} }
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -182,14 +183,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -86,7 +87,7 @@ func (s *stateSnapshotStatements) InsertState(
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i]) nids[i] = int64(stateBlockNIDs[i])
} }
err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
return return
} }

View File

@ -105,12 +105,12 @@ func (s *stateBlockStatements) BulkInsertStateData(
return 0, nil return 0, nil
} }
var stateBlockNID types.StateBlockNID var stateBlockNID types.StateBlockNID
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
if err != nil { if err != nil {
return 0, err return 0, err
} }
for _, entry := range entries { for _, entry := range entries {
_, err = txn.Stmt(s.insertStateDataStmt).ExecContext( _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
ctx, ctx,
int64(stateBlockNID), int64(stateBlockNID),
int64(entry.EventTypeNID), int64(entry.EventTypeNID),

View File

@ -76,7 +76,7 @@ func (s *stateSnapshotStatements) InsertState(
if err != nil { if err != nil {
return return
} }
insertStmt := txn.Stmt(s.insertStateStmt) insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -81,7 +81,7 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
return return
} }
@ -110,7 +110,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return return
} }

View File

@ -160,13 +160,13 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) _, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
return return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
return return
} }

View File

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -85,7 +86,7 @@ func (s *accountDataStatements) InsertAccountData(
if err != nil { if err != nil {
return return
} }
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
return return
} }
@ -147,7 +148,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
var nullableID sql.NullInt64 var nullableID sql.NullInt64
err = txn.Stmt(s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid { if nullableID.Valid {
id = nullableID.Int64 id = nullableID.Int64
} }

View File

@ -84,7 +84,7 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
return err return err
} }
@ -113,7 +113,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err return err
} }

View File

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -75,7 +76,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := txn.Stmt(s.insertAccountDataStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }

View File

@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -99,7 +100,7 @@ func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := txn.Stmt(s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if appserviceID == "" {
@ -162,7 +163,7 @@ func (s *accountsStatements) selectNewNumericLocalpart(
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt
if txn != nil { if txn != nil {
stmt = txn.Stmt(stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return return

View File

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const profilesSchema = ` const profilesSchema = `
@ -84,7 +85,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }

View File

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -75,7 +77,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
_, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return err return err
} }

View File

@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -104,9 +105,9 @@ func (s *accountsStatements) insertAccount(
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else { } else {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -163,7 +164,7 @@ func (s *accountsStatements) selectNewNumericLocalpart(
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt
if txn != nil { if txn != nil {
stmt = txn.Stmt(stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return return

View File

@ -87,7 +87,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) error { ) error {
_, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return err return err
} }