Don't overwrite global err before return (#1293)

Signed-off-by: Olivier Charvin <git@olivier.pfad.fr>
main
oliverpool 2020-08-25 14:11:52 +02:00 committed by GitHub
parent c8b873abc8
commit a4db43e096
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 69 additions and 64 deletions

View File

@ -111,19 +111,19 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
eventsRemaining bool, eventsRemaining bool,
err error, err error,
) { ) {
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer func() { defer func() {
err = eventRows.Close()
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": applicationServiceID, "appservice": applicationServiceID,
}).WithError(err).Fatalf("appservice unable to select new events to send") }).WithError(err).Fatalf("appservice unable to select new events to send")
} }
}() }()
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
if err != nil { if err != nil {
return return
@ -132,6 +132,13 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)

View File

@ -116,19 +116,19 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
eventsRemaining bool, eventsRemaining bool,
err error, err error,
) { ) {
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer func() { defer func() {
err = eventRows.Close()
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": applicationServiceID, "appservice": applicationServiceID,
}).WithError(err).Fatalf("appservice unable to select new events to send") }).WithError(err).Fatalf("appservice unable to select new events to send")
} }
}() }()
// Retrieve events from the database. Unsuccessfully sent events first
eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
if err != nil {
return
}
defer checkNamedErr(eventRows.Close, &err)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
if err != nil { if err != nil {
return return
@ -137,6 +137,13 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
return return
} }
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}
func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age // Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)

View File

@ -101,6 +101,9 @@ func worker(db storage.Database, ws types.ApplicationServiceWorkerState) {
// Backoff if the application service does not respond // Backoff if the application service does not respond
err = send(client, ws.AppService, txnID, transactionJSON) err = send(client, ws.AppService, txnID, transactionJSON)
if err != nil { if err != nil {
log.WithFields(log.Fields{
"appservice": ws.AppService.ID,
}).WithError(err).Error("unable to send event")
// Backoff // Backoff
backoff(&ws, err) backoff(&ws, err)
continue continue
@ -207,7 +210,7 @@ func send(
appservice config.ApplicationService, appservice config.ApplicationService,
txnID int, txnID int,
transaction []byte, transaction []byte,
) error { ) (err error) {
// PUT a transaction to our AS // PUT a transaction to our AS
// https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid // https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid
address := fmt.Sprintf("%s/transactions/%d?access_token=%s", appservice.URL, txnID, url.QueryEscape(appservice.HSToken)) address := fmt.Sprintf("%s/transactions/%d?access_token=%s", appservice.URL, txnID, url.QueryEscape(appservice.HSToken))
@ -220,14 +223,7 @@ func send(
if err != nil { if err != nil {
return err return err
} }
defer func() { defer checkNamedErr(resp.Body.Close, &err)
err := resp.Body.Close()
if err != nil {
log.WithFields(log.Fields{
"appservice": appservice.ID,
}).WithError(err).Error("unable to close response body from application service")
}
}()
// Check the AS received the events correctly // Check the AS received the events correctly
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
@ -237,3 +233,10 @@ func send(
return nil return nil
} }
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}

View File

@ -18,8 +18,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"strings" "strings"
"github.com/matrix-org/util"
) )
// A PartitionOffset is the offset into a partition of the input log. // A PartitionOffset is the offset into a partition of the input log.
@ -99,26 +97,28 @@ func (s *PartitionOffsetStatements) SetPartitionOffset(
// selectPartitionOffsets returns all the partition offsets for the given topic. // selectPartitionOffsets returns all the partition offsets for the given topic.
func (s *PartitionOffsetStatements) selectPartitionOffsets( func (s *PartitionOffsetStatements) selectPartitionOffsets(
ctx context.Context, topic string, ctx context.Context, topic string,
) ([]PartitionOffset, error) { ) (results []PartitionOffset, err error) {
rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic) rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer checkNamedErr(rows.Close, &err)
err2 := rows.Close()
if err2 != nil {
util.GetLogger(ctx).WithError(err2).Error("selectPartitionOffsets: rows.close() failed")
}
}()
var results []PartitionOffset
for rows.Next() { for rows.Next() {
var offset PartitionOffset var offset PartitionOffset
if err := rows.Scan(&offset.Partition, &offset.Offset); err != nil { if err = rows.Scan(&offset.Partition, &offset.Offset); err != nil {
return nil, err return nil, err
} }
results = append(results, offset) results = append(results, offset)
} }
return results, rows.Err() err = rows.Err()
return results, err
}
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
} }
// UpsertPartitionOffset updates or inserts the partition offset for the given topic. // UpsertPartitionOffset updates or inserts the partition offset for the given topic.

View File

@ -38,9 +38,18 @@ type Transaction interface {
// was applied correctly. For example, 'database is locked' errors in sqlite will happen here. // was applied correctly. For example, 'database is locked' errors in sqlite will happen here.
func EndTransaction(txn Transaction, succeeded *bool) error { func EndTransaction(txn Transaction, succeeded *bool) error {
if *succeeded { if *succeeded {
return txn.Commit() // nolint: errcheck return txn.Commit()
} else { } else {
return txn.Rollback() // nolint: errcheck return txn.Rollback()
}
}
// EndTransactionWithCheck ends a transaction and overwrites the error pointer if its value was nil.
// If the transaction succeeded then it is committed, otherwise it is rolledback.
// Designed to be used with defer (see EndTransaction otherwise).
func EndTransactionWithCheck(txn Transaction, succeeded *bool, err *error) {
if e := EndTransaction(txn, succeeded); e != nil && *err == nil {
*err = e
} }
} }
@ -53,12 +62,7 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err) return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err)
} }
succeeded := false succeeded := false
defer func() { defer EndTransactionWithCheck(txn, &succeeded, &err)
err2 := EndTransaction(txn, &succeeded)
if err == nil && err2 != nil { // failed to commit/rollback
err = err2
}
}()
err = fn(txn) err = fn(txn)
if err != nil { if err != nil {

View File

@ -60,12 +60,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
} }
succeeded := false succeeded := false
defer func() { defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
txerr := sqlutil.EndTransaction(updater, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
u := latestEventsUpdater{ u := latestEventsUpdater{
ctx: ctx, ctx: ctx,

View File

@ -95,9 +95,8 @@ func (s *OutputKeyChangeEventConsumer) updateOffset(msg *sarama.ConsumerMessage)
} }
func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
defer func() { defer s.updateOffset(msg)
s.updateOffset(msg)
}()
var output api.DeviceMessage var output api.DeviceMessage
if err := json.Unmarshal(msg.Value, &output); err != nil { if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream // If the message was invalid, log it and move on to the next message in the stream

View File

@ -455,13 +455,8 @@ func (d *Database) addPDUDeltaToResponse(
if err != nil { if err != nil {
return nil, err return nil, err
} }
var succeeded bool succeeded := false
defer func() { defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
txerr := sqlutil.EndTransaction(txn, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
@ -641,13 +636,8 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
if err != nil { if err != nil {
return return
} }
var succeeded bool succeeded := false
defer func() { defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
txerr := sqlutil.EndTransaction(txn, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
// Get the current sync position which we will base the sync response on. // Get the current sync position which we will base the sync response on.
toPos, err = d.syncPositionTx(ctx, txn) toPos, err = d.syncPositionTx(ctx, txn)