diff --git a/clientapi/routing/rate_limiting.go b/clientapi/routing/rate_limiting.go index 16e3c056..9d3f817a 100644 --- a/clientapi/routing/rate_limiting.go +++ b/clientapi/routing/rate_limiting.go @@ -13,6 +13,7 @@ import ( type rateLimits struct { limits map[string]chan struct{} limitsMutex sync.RWMutex + cleanMutex sync.RWMutex enabled bool requestThreshold int64 cooloffDuration time.Duration @@ -38,6 +39,7 @@ func (l *rateLimits) clean() { // empty. If they are then we will close and delete them, // freeing up memory. time.Sleep(time.Second * 30) + l.cleanMutex.Lock() l.limitsMutex.Lock() for k, c := range l.limits { if len(c) == 0 { @@ -46,6 +48,7 @@ func (l *rateLimits) clean() { } } l.limitsMutex.Unlock() + l.cleanMutex.Unlock() } } @@ -55,12 +58,12 @@ func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse { return nil } - // Lock the map long enough to check for rate limiting. We hold it - // for longer here than we really need to but it makes sure that we - // also don't conflict with the cleaner goroutine which might clean - // up a channel after we have retrieved it otherwise. - l.limitsMutex.RLock() - defer l.limitsMutex.RUnlock() + // Take a read lock out on the cleaner mutex. The cleaner expects to + // be able to take a write lock, which isn't possible while there are + // readers, so this has the effect of blocking the cleaner goroutine + // from doing its work until there are no requests in flight. + l.cleanMutex.RLock() + defer l.cleanMutex.RUnlock() // First of all, work out if X-Forwarded-For was sent to us. If not // then we'll just use the IP address of the caller. @@ -69,12 +72,19 @@ func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse { caller = forwardedFor } - // Look up the caller's channel, if they have one. If they don't then - // let's create one. + // Look up the caller's channel, if they have one. + l.limitsMutex.RLock() rateLimit, ok := l.limits[caller] + l.limitsMutex.RUnlock() + + // If the caller doesn't have a channel, create one and write it + // back to the map. if !ok { - l.limits[caller] = make(chan struct{}, l.requestThreshold) - rateLimit = l.limits[caller] + rateLimit = make(chan struct{}, l.requestThreshold) + + l.limitsMutex.Lock() + l.limits[caller] = rateLimit + l.limitsMutex.Unlock() } // Check if the user has got free resource slots for this request.