diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index eed58c15..c0dd42c5 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -110,9 +110,10 @@ func (s *inviteEventsStatements) InsertInviteEvent( } func (s *inviteEventsStatements) DeleteInviteEvent( - ctx context.Context, inviteEventID string, + ctx context.Context, txn *sql.Tx, inviteEventID string, ) (sp types.StreamPosition, err error) { - err = s.deleteInviteEventStmt.QueryRowContext(ctx, inviteEventID).Scan(&sp) + stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt) + err = stmt.QueryRowContext(ctx, inviteEventID).Scan(&sp) return } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 37793ba2..3388473a 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -139,8 +139,8 @@ func (d *Database) GetStateEventsForRoom( func (d *Database) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ) (sp types.StreamPosition, err error) { - _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent) + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) return nil }) return @@ -151,8 +151,8 @@ func (d *Database) AddInviteEvent( func (d *Database) RetireInviteEvent( ctx context.Context, inviteEventID string, ) (sp types.StreamPosition, err error) { - _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID) + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID) return nil }) return @@ -422,7 +422,7 @@ func (d *Database) addPDUDeltaToResponse( wantFullState bool, res *types.Response, ) (joinedRoomIDs []string, err error) { - txn, err := d.DB.BeginTx(context.TODO(), &txReadOnlySnapshot) // TODO: check mattn/go-sqlite3#764 + txn, err := d.DB.BeginTx(ctx, &txReadOnlySnapshot) if err != nil { return nil, err } @@ -606,7 +606,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync( // a consistent view of the database throughout. This includes extracting the sync position. // This does have the unfortunate side-effect that all the matrixy logic resides in this function, // but it's better to not hide the fact that this is being done in a transaction. - txn, err := d.DB.BeginTx(context.TODO(), &txReadOnlySnapshot) // TODO: check mattn/go-sqlite3#764 + txn, err := d.DB.BeginTx(ctx, &txReadOnlySnapshot) if err != nil { return } @@ -1063,15 +1063,6 @@ func (d *Database) SendToDeviceUpdatesWaiting( return count > 0, nil } -func (d *Database) AddSendToDeviceEvent( - ctx context.Context, txn *sql.Tx, - userID, deviceID, content string, -) error { - return d.SendToDevice.InsertSendToDeviceMessage( - ctx, txn, userID, deviceID, content, - ) -} - func (d *Database) StoreNewSendForDeviceMessage( ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, ) (types.StreamPosition, error) { @@ -1082,7 +1073,7 @@ func (d *Database) StoreNewSendForDeviceMessage( // Delegate the database write task to the SendToDeviceWriter. It'll guarantee // that we don't lock the table for writes in more than one place. err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.AddSendToDeviceEvent( + return d.SendToDevice.InsertSendToDeviceMessage( ctx, txn, userID, deviceID, string(j), ) }) diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 7da86683..1a36ad40 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -117,13 +117,14 @@ func (s *inviteEventsStatements) InsertInviteEvent( } func (s *inviteEventsStatements) DeleteInviteEvent( - ctx context.Context, inviteEventID string, + ctx context.Context, txn *sql.Tx, inviteEventID string, ) (types.StreamPosition, error) { - streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil) + streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) if err != nil { return streamPos, err } - _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) + stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt) + _, err = stmt.ExecContext(ctx, streamPos, inviteEventID) return streamPos, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 2ff229cb..38f6d848 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -32,7 +32,7 @@ type AccountData interface { type Invites interface { InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) - DeleteInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error) + DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error) // SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value // for the room. SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]gomatrixserverlib.HeaderedEvent, retired map[string]gomatrixserverlib.HeaderedEvent, err error) diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index aee8db6e..f9430c24 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -18,8 +18,6 @@ import ( "context" "database/sql" "encoding/json" - - "github.com/matrix-org/dendrite/internal/sqlutil" ) const accountDataSchema = ` @@ -51,15 +49,13 @@ const selectAccountDataByTypeSQL = "" + type accountDataStatements struct { db *sql.DB - writer sqlutil.Writer insertAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { +func (s *accountDataStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = writer _, err = db.Exec(accountDataSchema) if err != nil { return @@ -78,11 +74,9 @@ func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err func (s *accountDataStatements) insertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, -) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) - return err - }) +) error { + _, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + return err } func (s *accountDataStatements) selectAccountData( diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 83b90668..798a6de9 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -59,7 +58,6 @@ const selectNewNumericLocalpartSQL = "" + type accountsStatements struct { db *sql.DB - writer sqlutil.Writer insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt @@ -67,9 +65,9 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { +func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = writer + _, err = db.Exec(accountsSchema) if err != nil { return @@ -99,15 +97,12 @@ func (s *accountsStatements) insertAccount( createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - var err error - if appserviceID == "" { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) - } else { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) - } - return err - }) + var err error + if appserviceID == "" { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } if err != nil { return nil, err } diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index 1ec45e03..4eeaf037 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -53,7 +53,6 @@ const selectProfilesBySearchSQL = "" + type profilesStatements struct { db *sql.DB - writer sqlutil.Writer insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt @@ -61,9 +60,8 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { +func (s *profilesStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = writer _, err = db.Exec(profilesSchema) if err != nil { return @@ -88,11 +86,9 @@ func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err err func (s *profilesStatements) insertProfile( ctx context.Context, txn *sql.Tx, localpart string, -) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") - return err - }) +) error { + _, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + return err } func (s *profilesStatements) selectProfileByLocalpart( @@ -109,16 +105,18 @@ func (s *profilesStatements) selectProfileByLocalpart( } func (s *profilesStatements) setAvatarURL( - ctx context.Context, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { - _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) + stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) + _, err = stmt.ExecContext(ctx, avatarURL, localpart) return } func (s *profilesStatements) setDisplayName( - ctx context.Context, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { - _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) + stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) + _, err = stmt.ExecContext(ctx, displayName, localpart) return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 4f45f754..46106297 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -64,16 +64,16 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = partitions.Prepare(db, d.writer, "account"); err != nil { return nil, err } - if err = d.accounts.prepare(db, d.writer, serverName); err != nil { + if err = d.accounts.prepare(db, serverName); err != nil { return nil, err } - if err = d.profiles.prepare(db, d.writer); err != nil { + if err = d.profiles.prepare(db); err != nil { return nil, err } - if err = d.accountDatas.prepare(db, d.writer); err != nil { + if err = d.accountDatas.prepare(db); err != nil { return nil, err } - if err = d.threepids.prepare(db, d.writer); err != nil { + if err = d.threepids.prepare(db); err != nil { return nil, err } return d, nil @@ -109,7 +109,9 @@ func (d *Database) SetAvatarURL( ) error { d.profilesMu.Lock() defer d.profilesMu.Unlock() - return d.profiles.setAvatarURL(ctx, localpart, avatarURL) + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) + }) } // SetDisplayName updates the display name of the profile associated with the given @@ -119,7 +121,9 @@ func (d *Database) SetDisplayName( ) error { d.profilesMu.Lock() defer d.profilesMu.Unlock() - return d.profiles.setDisplayName(ctx, localpart, displayName) + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.profiles.setDisplayName(ctx, txn, localpart, displayName) + }) } // CreateGuestAccount makes a new guest account and creates an empty profile @@ -136,7 +140,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er defer d.profilesMu.Unlock() defer d.accountDatasMu.Unlock() defer d.accountsMu.Unlock() - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil { @@ -162,7 +166,7 @@ func (d *Database) CreateAccount( defer d.profilesMu.Unlock() defer d.accountDatasMu.Unlock() defer d.accountsMu.Unlock() - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err }) @@ -214,7 +218,7 @@ func (d *Database) SaveAccountData( ) error { d.accountDatasMu.Lock() defer d.accountDatasMu.Unlock() - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -267,7 +271,7 @@ func (d *Database) SaveThreePIDAssociation( ) (err error) { d.threepidsMu.Lock() defer d.threepidsMu.Unlock() - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, ) @@ -292,7 +296,9 @@ func (d *Database) RemoveThreePIDAssociation( ) (err error) { d.threepidsMu.Lock() defer d.threepidsMu.Unlock() - return d.threepids.deleteThreePID(ctx, threepid, medium) + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.threepids.deleteThreePID(ctx, txn, threepid, medium) + }) } // GetLocalpartForThreePID looks up the localpart associated with a given third-party diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 230978fe..43112d38 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -54,16 +54,14 @@ const deleteThreePIDSQL = "" + type threepidStatements struct { db *sql.DB - writer sqlutil.Writer selectLocalpartForThreePIDStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt insertThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { +func (s *threepidStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = writer _, err = db.Exec(threepidSchema) if err != nil { return @@ -122,18 +120,14 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) insertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err := stmt.ExecContext(ctx, threepid, medium, localpart) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + return err } func (s *threepidStatements) deleteThreePID( - ctx context.Context, threepid string, medium string) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) - _, err := stmt.ExecContext(ctx, threepid, medium) - return err - }) + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) + return err }