diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index e9e117a7..57612908 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -231,13 +231,24 @@ func (oq *destinationQueue) backgroundSend() { // If we are backing off this server then wait for the // backoff duration to complete first, or until explicitly // told to retry. - if _, giveUp := oq.statistics.BackoffIfRequired(oq.backingOff, oq.interruptBackoff); giveUp { + until, blacklisted := oq.statistics.BackoffInfo() + if blacklisted { // It's been suggested that we should give up because the backoff // has exceeded a maximum allowable value. Clean up the in-memory // buffers at this point. The PDU clean-up is already on a defer. log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) return } + if until != nil { + // We haven't backed off yet, so wait for the suggested amount of + // time. + duration := time.Until(*until) + log.Warnf("Backing off %q for %s", oq.destination, duration) + select { + case <-time.After(duration): + case <-oq.interruptBackoff: + } + } // If we have pending PDUs or EDUs then construct a transaction. if pendingPDUs || pendingEDUs { diff --git a/federationsender/statistics/statistics.go b/federationsender/statistics/statistics.go index 03ef64e9..b5fe7513 100644 --- a/federationsender/statistics/statistics.go +++ b/federationsender/statistics/statistics.go @@ -44,6 +44,7 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS server = &ServerStatistics{ statistics: s, serverName: serverName, + interrupt: make(chan struct{}), } s.servers[serverName] = server s.mutex.Unlock() @@ -68,6 +69,7 @@ type ServerStatistics struct { backoffStarted atomic.Bool // is the backoff started backoffUntil atomic.Value // time.Time until this backoff interval ends backoffCount atomic.Uint32 // number of times BackoffDuration has been called + interrupt chan struct{} // interrupts the backoff goroutine successCounter atomic.Uint32 // how many times have we succeeded? } @@ -76,15 +78,24 @@ func (s *ServerStatistics) duration(count uint32) time.Duration { return time.Second * time.Duration(math.Exp2(float64(count))) } +// cancel will interrupt the currently active backoff. +func (s *ServerStatistics) cancel() { + s.blacklisted.Store(false) + s.backoffUntil.Store(time.Time{}) + select { + case s.interrupt <- struct{}{}: + default: + } +} + // Success updates the server statistics with a new successful // attempt, which increases the sent counter and resets the idle and // failure counters. If a host was blacklisted at this point then // we will unblacklist it. func (s *ServerStatistics) Success() { - s.successCounter.Add(1) - s.backoffStarted.Store(false) + s.cancel() + s.successCounter.Inc() s.backoffCount.Store(0) - s.blacklisted.Store(false) if s.statistics.DB != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) @@ -99,10 +110,30 @@ func (s *ServerStatistics) Success() { // whether we have blacklisted and therefore to give up. func (s *ServerStatistics) Failure() (time.Time, bool) { // If we aren't already backing off, this call will start - // a new backoff period. Reset the counter to 0 so that - // we backoff only for short periods of time to start with. + // a new backoff period. Increase the failure counter and + // start a goroutine which will wait out the backoff and + // unset the backoffStarted flag when done. if s.backoffStarted.CAS(false, true) { - s.backoffCount.Store(0) + if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist { + s.blacklisted.Store(true) + if s.statistics.DB != nil { + if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) + } + } + return time.Time{}, true + } + + go func() { + until, ok := s.backoffUntil.Load().(time.Time) + if ok { + select { + case <-time.After(time.Until(until)): + case <-s.interrupt: + } + } + s.backoffStarted.Store(false) + }() } // Check if we have blacklisted this node. @@ -136,53 +167,6 @@ func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) { return nil, s.blacklisted.Load() } -// BackoffIfRequired will block for as long as the current -// backoff requires, if needed. Otherwise it will do nothing. -// Returns the amount of time to backoff for and whether to give up or not. -func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <-chan bool) (time.Duration, bool) { - if started := s.backoffStarted.Load(); !started { - return 0, false - } - - // Work out if we should be blacklisting at this point. - count := s.backoffCount.Inc() - if count >= s.statistics.FailuresUntilBlacklist { - // We've exceeded the maximum amount of times we're willing - // to back off, which is probably in the region of hours by - // now. Mark the host as blacklisted and tell the caller to - // give up. - s.blacklisted.Store(true) - if s.statistics.DB != nil { - if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) - } - } - return 0, true - } - - // Work out when we should wait until. - duration := s.duration(count) - until := time.Now().Add(duration) - s.backoffUntil.Store(until) - - // Notify the destination queue that we're backing off now. - backingOff.Store(true) - defer backingOff.Store(false) - - // Work out how long we should be backing off for. - logrus.Warnf("Backing off %q for %s", s.serverName, duration) - - // Wait for either an interruption or for the backoff to - // complete. - select { - case <-interrupt: - logrus.Debugf("Interrupting backoff for %q", s.serverName) - case <-time.After(duration): - } - - return duration, false -} - // Blacklisted returns true if the server is blacklisted and false // otherwise. func (s *ServerStatistics) Blacklisted() bool { diff --git a/federationsender/statistics/statistics_test.go b/federationsender/statistics/statistics_test.go index 7e083de6..225350b6 100644 --- a/federationsender/statistics/statistics_test.go +++ b/federationsender/statistics/statistics_test.go @@ -4,8 +4,6 @@ import ( "math" "testing" "time" - - "go.uber.org/atomic" ) func TestBackoff(t *testing.T) { @@ -27,34 +25,30 @@ func TestBackoff(t *testing.T) { server.Failure() t.Logf("Backoff counter: %d", server.backoffCount.Load()) - backingOff := atomic.Bool{} // Now we're going to simulate backing off a few times to see // what happens. for i := uint32(1); i <= 10; i++ { - // Interrupt the backoff - it doesn't really matter if it - // completes but we will find out how long the backoff should - // have been. - interrupt := make(chan bool, 1) - close(interrupt) - - // Get the duration. - duration, blacklist := server.BackoffIfRequired(backingOff, interrupt) - // Register another failure for good measure. This should have no // side effects since a backoff is already in progress. If it does // then we'll fail. until, blacklisted := server.Failure() - if time.Until(until) > duration { - t.Fatal("Failure produced unexpected side effect when it shouldn't have") - } + + // Get the duration. + _, blacklist := server.BackoffInfo() + duration := time.Until(until).Round(time.Second) + + // Unset the backoff, or otherwise our next call will think that + // there's a backoff in progress and return the same result. + server.cancel() + server.backoffStarted.Store(false) // Check if we should be blacklisted by now. if i >= stats.FailuresUntilBlacklist { if !blacklist { t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) } else if blacklist != blacklisted { - t.Fatalf("BackoffIfRequired and Failure returned different blacklist values") + t.Fatalf("BackoffInfo and Failure returned different blacklist values") } else { t.Logf("Backoff %d is blacklisted as expected", i) continue