diff --git a/appservice/storage/postgres/appservice_events_table.go b/appservice/storage/postgres/appservice_events_table.go index d33a83b1..a95be6b8 100644 --- a/appservice/storage/postgres/appservice_events_table.go +++ b/appservice/storage/postgres/appservice_events_table.go @@ -111,19 +111,19 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( eventsRemaining bool, err error, ) { - // Retrieve events from the database. Unsuccessfully sent events first - eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID) - if err != nil { - return - } defer func() { - err = eventRows.Close() if err != nil { log.WithFields(log.Fields{ "appservice": applicationServiceID, }).WithError(err).Fatalf("appservice unable to select new events to send") } }() + // Retrieve events from the database. Unsuccessfully sent events first + eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID) + if err != nil { + return + } + defer checkNamedErr(eventRows.Close, &err) events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) if err != nil { return @@ -132,6 +132,13 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( return } +// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil +func checkNamedErr(fn func() error, err *error) { + if e := fn(); e != nil && *err == nil { + *err = e + } +} + func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { // Get current time for use in calculating event age nowMilli := time.Now().UnixNano() / int64(time.Millisecond) diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go index 5dfb72f6..34b4859e 100644 --- a/appservice/storage/sqlite3/appservice_events_table.go +++ b/appservice/storage/sqlite3/appservice_events_table.go @@ -116,19 +116,19 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( eventsRemaining bool, err error, ) { - // Retrieve events from the database. Unsuccessfully sent events first - eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID) - if err != nil { - return - } defer func() { - err = eventRows.Close() if err != nil { log.WithFields(log.Fields{ "appservice": applicationServiceID, }).WithError(err).Fatalf("appservice unable to select new events to send") } }() + // Retrieve events from the database. Unsuccessfully sent events first + eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID) + if err != nil { + return + } + defer checkNamedErr(eventRows.Close, &err) events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) if err != nil { return @@ -137,6 +137,13 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( return } +// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil +func checkNamedErr(fn func() error, err *error) { + if e := fn(); e != nil && *err == nil { + *err = e + } +} + func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { // Get current time for use in calculating event age nowMilli := time.Now().UnixNano() / int64(time.Millisecond) diff --git a/appservice/workers/transaction_scheduler.go b/appservice/workers/transaction_scheduler.go index 63ec58aa..b1735841 100644 --- a/appservice/workers/transaction_scheduler.go +++ b/appservice/workers/transaction_scheduler.go @@ -101,6 +101,9 @@ func worker(db storage.Database, ws types.ApplicationServiceWorkerState) { // Backoff if the application service does not respond err = send(client, ws.AppService, txnID, transactionJSON) if err != nil { + log.WithFields(log.Fields{ + "appservice": ws.AppService.ID, + }).WithError(err).Error("unable to send event") // Backoff backoff(&ws, err) continue @@ -207,7 +210,7 @@ func send( appservice config.ApplicationService, txnID int, transaction []byte, -) error { +) (err error) { // PUT a transaction to our AS // https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid address := fmt.Sprintf("%s/transactions/%d?access_token=%s", appservice.URL, txnID, url.QueryEscape(appservice.HSToken)) @@ -220,14 +223,7 @@ func send( if err != nil { return err } - defer func() { - err := resp.Body.Close() - if err != nil { - log.WithFields(log.Fields{ - "appservice": appservice.ID, - }).WithError(err).Error("unable to close response body from application service") - } - }() + defer checkNamedErr(resp.Body.Close, &err) // Check the AS received the events correctly if resp.StatusCode != http.StatusOK { @@ -237,3 +233,10 @@ func send( return nil } + +// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil +func checkNamedErr(fn func() error, err *error) { + if e := fn(); e != nil && *err == nil { + *err = e + } +} diff --git a/internal/sqlutil/partition_offset_table.go b/internal/sqlutil/partition_offset_table.go index be079442..e19a092f 100644 --- a/internal/sqlutil/partition_offset_table.go +++ b/internal/sqlutil/partition_offset_table.go @@ -18,8 +18,6 @@ import ( "context" "database/sql" "strings" - - "github.com/matrix-org/util" ) // A PartitionOffset is the offset into a partition of the input log. @@ -99,26 +97,28 @@ func (s *PartitionOffsetStatements) SetPartitionOffset( // selectPartitionOffsets returns all the partition offsets for the given topic. func (s *PartitionOffsetStatements) selectPartitionOffsets( ctx context.Context, topic string, -) ([]PartitionOffset, error) { +) (results []PartitionOffset, err error) { rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic) if err != nil { return nil, err } - defer func() { - err2 := rows.Close() - if err2 != nil { - util.GetLogger(ctx).WithError(err2).Error("selectPartitionOffsets: rows.close() failed") - } - }() - var results []PartitionOffset + defer checkNamedErr(rows.Close, &err) for rows.Next() { var offset PartitionOffset - if err := rows.Scan(&offset.Partition, &offset.Offset); err != nil { + if err = rows.Scan(&offset.Partition, &offset.Offset); err != nil { return nil, err } results = append(results, offset) } - return results, rows.Err() + err = rows.Err() + return results, err +} + +// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil +func checkNamedErr(fn func() error, err *error) { + if e := fn(); e != nil && *err == nil { + *err = e + } } // UpsertPartitionOffset updates or inserts the partition offset for the given topic. diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index d296c418..1d2825d5 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -38,9 +38,18 @@ type Transaction interface { // was applied correctly. For example, 'database is locked' errors in sqlite will happen here. func EndTransaction(txn Transaction, succeeded *bool) error { if *succeeded { - return txn.Commit() // nolint: errcheck + return txn.Commit() } else { - return txn.Rollback() // nolint: errcheck + return txn.Rollback() + } +} + +// EndTransactionWithCheck ends a transaction and overwrites the error pointer if its value was nil. +// If the transaction succeeded then it is committed, otherwise it is rolledback. +// Designed to be used with defer (see EndTransaction otherwise). +func EndTransactionWithCheck(txn Transaction, succeeded *bool, err *error) { + if e := EndTransaction(txn, succeeded); e != nil && *err == nil { + *err = e } } @@ -53,12 +62,7 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err) } succeeded := false - defer func() { - err2 := EndTransaction(txn, &succeeded) - if err == nil && err2 != nil { // failed to commit/rollback - err = err2 - } - }() + defer EndTransactionWithCheck(txn, &succeeded, &err) err = fn(txn) if err != nil { diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index 3be5218d..f11a78d7 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -60,12 +60,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) } succeeded := false - defer func() { - txerr := sqlutil.EndTransaction(updater, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) u := latestEventsUpdater{ ctx: ctx, diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index e14d2223..ee95e09d 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -95,9 +95,8 @@ func (s *OutputKeyChangeEventConsumer) updateOffset(msg *sarama.ConsumerMessage) } func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { - defer func() { - s.updateOffset(msg) - }() + defer s.updateOffset(msg) + var output api.DeviceMessage if err := json.Unmarshal(msg.Value, &output); err != nil { // If the message was invalid, log it and move on to the next message in the stream diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 699a6647..4031dc74 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -455,13 +455,8 @@ func (d *Database) addPDUDeltaToResponse( if err != nil { return nil, err } - var succeeded bool - defer func() { - txerr := sqlutil.EndTransaction(txn, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() + succeeded := false + defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request @@ -641,13 +636,8 @@ func (d *Database) getResponseWithPDUsForCompleteSync( if err != nil { return } - var succeeded bool - defer func() { - txerr := sqlutil.EndTransaction(txn, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() + succeeded := false + defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) // Get the current sync position which we will base the sync response on. toPos, err = d.syncPositionTx(ctx, txn)