diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index aedaeab1..9ccfbace 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -262,15 +262,13 @@ func (oq *destinationQueue) backgroundSend() { // If we are backing off this server then wait for the // backoff duration to complete first, or until explicitly // told to retry. - if backoff, duration := oq.statistics.BackoffDuration(); backoff { - log.WithField("duration", duration).Debugf("Backing off %s", oq.destination) - oq.backingOff.Store(true) - select { - case <-time.After(duration): - case <-oq.interruptBackoff: - log.Debugf("Interrupting backoff for %q", oq.destination) - } - oq.backingOff.Store(false) + if _, giveUp := oq.statistics.BackoffIfRequired(oq.backingOff, oq.interruptBackoff); giveUp { + // It's been suggested that we should give up because the backoff + // has exceeded a maximum allowable value. Clean up the in-memory + // buffers at this point. The PDU clean-up is already on a defer. + oq.cleanPendingInvites() + log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) + return } // If we have pending PDUs or EDUs then construct a transaction. @@ -278,24 +276,8 @@ func (oq *destinationQueue) backgroundSend() { // Try sending the next transaction and see what happens. transaction, terr := oq.nextTransaction() if terr != nil { - // We failed to send the transaction. - if giveUp := oq.statistics.Failure(); giveUp { - // It's been suggested that we should give up because the backoff - // has exceeded a maximum allowable value. Clean up the in-memory - // buffers at this point. The PDU clean-up is already on a defer. - oq.cleanPendingInvites() - log.Infof("Blacklisting %q due to errors", oq.destination) - return - } else { - // We haven't been told to give up terminally yet but we still have - // PDUs waiting to be sent. By sending a message into the wake chan, - // the next loop iteration will try processing these PDUs again, - // subject to the backoff. - select { - case oq.notifyPDUs <- true: - default: - } - } + // We failed to send the transaction. Mark it as a failure. + oq.statistics.Failure() } else if transaction { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. @@ -307,14 +289,8 @@ func (oq *destinationQueue) backgroundSend() { if len(oq.pendingInvites) > 0 { sent, ierr := oq.nextInvites(oq.pendingInvites) if ierr != nil { - // We failed to send the transaction so increase the - // backoff and give it another go shortly. - if giveUp := oq.statistics.Failure(); giveUp { - // It's been suggested that we should give up because - // the backoff has exceeded a maximum allowable value. - log.Infof("Blacklisting %q due to errors", oq.destination) - return - } + // We failed to send the transaction. Mark it as a failure. + oq.statistics.Failure() } else if sent > 0 { // If we successfully sent the invites then clear out // the pending invites. diff --git a/federationsender/statistics/statistics.go b/federationsender/statistics/statistics.go index 17dd896d..0dd8da20 100644 --- a/federationsender/statistics/statistics.go +++ b/federationsender/statistics/statistics.go @@ -65,8 +65,8 @@ type ServerStatistics struct { statistics *Statistics // serverName gomatrixserverlib.ServerName // blacklisted atomic.Bool // is the node blacklisted - backoffUntil atomic.Value // time.Time to wait until before sending requests - failCounter atomic.Uint32 // how many times have we failed? + backoffStarted atomic.Bool // is the backoff started + backoffCount atomic.Uint32 // number of times BackoffDuration has been called successCounter atomic.Uint32 // how many times have we succeeded? } @@ -76,55 +76,67 @@ type ServerStatistics struct { // we will unblacklist it. func (s *ServerStatistics) Success() { s.successCounter.Add(1) - s.failCounter.Store(0) + s.backoffStarted.Store(false) + s.backoffCount.Store(0) s.blacklisted.Store(false) - if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + if s.statistics.DB != nil { + if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + } } } -// Failure marks a failure and works out when to backoff until. It -// returns true if the worker should give up altogether because of -// too many consecutive failures. At this point the host is marked -// as blacklisted. -func (s *ServerStatistics) Failure() bool { - // Increase the fail counter. - failCounter := s.failCounter.Add(1) +// Failure marks a failure and starts backing off if needed. +// The next call to BackoffIfRequired will do the right thing +// after this. +func (s *ServerStatistics) Failure() { + if s.backoffStarted.CAS(false, true) { + s.backoffCount.Store(0) + } +} - // Check that we haven't failed more times than is acceptable. - if failCounter >= s.statistics.FailuresUntilBlacklist { +// BackoffIfRequired will block for as long as the current +// backoff requires, if needed. Otherwise it will do nothing. +func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <-chan bool) (time.Duration, bool) { + if started := s.backoffStarted.Load(); !started { + return 0, false + } + + // Work out how many times we've backed off so far. + count := s.backoffCount.Inc() + duration := time.Second * time.Duration(math.Exp2(float64(count))) + + // Work out if we should be blacklisting at this point. + if count >= s.statistics.FailuresUntilBlacklist { // We've exceeded the maximum amount of times we're willing // to back off, which is probably in the region of hours by // now. Mark the host as blacklisted and tell the caller to // give up. s.blacklisted.Store(true) - if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) + if s.statistics.DB != nil { + if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) + } } - return true + return duration, true } - // We're still under the threshold so work out the exponential - // backoff based on how many times we have failed already. The - // worker goroutine will wait until this time before processing - // anything from the queue. - backoffSeconds := time.Second * time.Duration(math.Exp2(float64(failCounter))) - s.backoffUntil.Store( - time.Now().Add(backoffSeconds), - ) - return false -} + // Notify the destination queue that we're backing off now. + backingOff.Store(true) + defer backingOff.Store(false) -// BackoffDuration returns both a bool stating whether to wait, -// and then if true, a duration to wait for. -func (s *ServerStatistics) BackoffDuration() (bool, time.Duration) { - backoff, until := false, time.Second - if b, ok := s.backoffUntil.Load().(time.Time); ok { - if b.After(time.Now()) { - backoff, until = true, time.Until(b) - } + // Work out how long we should be backing off for. + logrus.Warnf("Backing off %q for %s", s.serverName, duration) + + // Wait for either an interruption or for the backoff to + // complete. + select { + case <-interrupt: + logrus.Debugf("Interrupting backoff for %q", s.serverName) + case <-time.After(duration): } - return backoff, until + + return duration, false } // Blacklisted returns true if the server is blacklisted and false diff --git a/federationsender/statistics/statistics_test.go b/federationsender/statistics/statistics_test.go new file mode 100644 index 00000000..9050662e --- /dev/null +++ b/federationsender/statistics/statistics_test.go @@ -0,0 +1,60 @@ +package statistics + +import ( + "math" + "testing" + "time" + + "go.uber.org/atomic" +) + +func TestBackoff(t *testing.T) { + stats := Statistics{ + FailuresUntilBlacklist: 5, + } + server := ServerStatistics{ + statistics: &stats, + serverName: "test.com", + } + + // Start by checking that counting successes works. + server.Success() + if successes := server.SuccessCount(); successes != 1 { + t.Fatalf("Expected success count 1, got %d", successes) + } + + // Register a failure. + server.Failure() + + t.Logf("Backoff counter: %d", server.backoffCount.Load()) + backingOff := atomic.Bool{} + + // Now we're going to simulate backing off a few times to see + // what happens. + for i := uint32(1); i <= 10; i++ { + // Interrupt the backoff - it doesn't really matter if it + // completes but we will find out how long the backoff should + // have been. + interrupt := make(chan bool, 1) + close(interrupt) + + // Get the duration. + duration, blacklist := server.BackoffIfRequired(backingOff, interrupt) + + // Check if we should be blacklisted by now. + if i > stats.FailuresUntilBlacklist { + if !blacklist { + t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) + } else { + t.Logf("Backoff %d is blacklisted as expected", i) + continue + } + } + + // Check if the duration is what we expect. + t.Logf("Backoff %d is for %s", i, duration) + if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted { + t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration) + } + } +}