diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index e80c51df1..9756daf50 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -7,6 +7,7 @@ import ( "errors" "math" "net/http" + "strconv" "sync" "time" @@ -14,11 +15,25 @@ import ( "golang.org/x/time/rate" ) +// Rate limit response headers set by stores that implement RateLimiterStoreContext. +const ( + HeaderXRateLimitLimit = "X-RateLimit-Limit" + HeaderXRateLimitRemaining = "X-RateLimit-Remaining" +) + // RateLimiterStore is the interface to be implemented by custom stores. type RateLimiterStore interface { Allow(identifier string) (bool, error) } +// RateLimiterStoreContext is an optional interface a RateLimiterStore may implement. +// When the configured store implements it, the rate limiter calls AllowContext +// (with the request context) instead of Allow, allowing the store to set response +// headers such as Retry-After or X-RateLimit-* on the allow/deny decision. +type RateLimiterStoreContext interface { + AllowContext(c *echo.Context, identifier string) (bool, error) +} + // RateLimiterConfig defines the configuration for the rate limiter type RateLimiterConfig struct { Skipper Skipper @@ -136,7 +151,14 @@ func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { return config.ErrorHandler(c, err) } - if allow, allowErr := config.Store.Allow(identifier); !allow { + var allow bool + var allowErr error + if sc, ok := config.Store.(RateLimiterStoreContext); ok { + allow, allowErr = sc.AllowContext(c, identifier) + } else { + allow, allowErr = config.Store.Allow(identifier) + } + if !allow { return config.DenyHandler(c, identifier, allowErr) } return next(c) @@ -232,7 +254,22 @@ var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ // Allow implements RateLimiterStore.Allow func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { + _, allowed := store.allow(identifier) + return allowed, nil +} + +// AllowContext implements RateLimiterStoreContext: it makes the allow/deny decision +// and sets the X-RateLimit-* (and Retry-After when denied) response headers. +func (store *RateLimiterMemoryStore) AllowContext(c *echo.Context, identifier string) (bool, error) { + limiter, allowed := store.allow(identifier) + store.setRateLimitHeaders(c, limiter, allowed) + return allowed, nil +} + +func (store *RateLimiterMemoryStore) allow(identifier string) (*rate.Limiter, bool) { store.mutex.Lock() + defer store.mutex.Unlock() + limiter, exists := store.visitors[identifier] if !exists { limiter = new(Visitor) @@ -244,9 +281,26 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { if now.Sub(store.lastCleanup) > store.expiresIn { store.cleanupStaleVisitors(now) } - allowed := limiter.AllowN(now, 1) - store.mutex.Unlock() - return allowed, nil + return limiter.Limiter, limiter.AllowN(now, 1) +} + +func (store *RateLimiterMemoryStore) setRateLimitHeaders(c *echo.Context, limiter *rate.Limiter, allowed bool) { + header := c.Response().Header() + header.Set(HeaderXRateLimitLimit, strconv.Itoa(store.burst)) + + remaining := int(math.Floor(limiter.Tokens())) + if remaining < 0 { + remaining = 0 + } + header.Set(HeaderXRateLimitRemaining, strconv.Itoa(remaining)) + + if !allowed { + reservation := limiter.ReserveN(store.timeNow(), 1) + if delay := reservation.Delay(); delay > 0 { + header.Set(echo.HeaderRetryAfter, strconv.Itoa(int(math.Ceil(delay.Seconds())))) + } + reservation.Cancel() + } } /* diff --git a/middleware/rate_limiter_context_test.go b/middleware/rate_limiter_context_test.go new file mode 100644 index 000000000..629c01e47 --- /dev/null +++ b/middleware/rate_limiter_context_test.go @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/assert" +) + +// ctxAwareStore implements both Allow and the optional AllowContext. AllowContext +// gives the store the request context so it can set response headers (e.g. +// Retry-After / X-RateLimit-*) — see #2961. +type ctxAwareStore struct { + allowCalled bool + ctxAllowCalled bool + allow bool +} + +func (s *ctxAwareStore) Allow(identifier string) (bool, error) { + s.allowCalled = true + return s.allow, nil +} + +func (s *ctxAwareStore) AllowContext(c *echo.Context, identifier string) (bool, error) { + s.ctxAllowCalled = true + c.Response().Header().Set("Retry-After", "42") + return s.allow, nil +} + +// When the store implements AllowContext, the middleware must call it instead of +// Allow, so the store can set rate-limit headers on the response. +func TestRateLimiter_storeAllowContextIsPreferred(t *testing.T) { + e := echo.New() + store := &ctxAwareStore{allow: true} + mw := RateLimiterWithConfig(RateLimiterConfig{ + Store: store, + IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil }, + }) + handler := mw(func(c *echo.Context) error { return c.String(http.StatusOK, "ok") }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + assert.NoError(t, handler(c)) + assert.True(t, store.ctxAllowCalled, "AllowContext should be called when implemented") + assert.False(t, store.allowCalled, "Allow should not be called when AllowContext is implemented") + assert.Equal(t, "42", rec.Header().Get("Retry-After"), "store should be able to set headers via the context") +} + +// The built-in memory store implements AllowContext, so it sets X-RateLimit-Limit / +// X-RateLimit-Remaining on every request and Retry-After when the limit is hit (#2961). +func TestRateLimiterMemoryStore_AllowContextSetsHeaders(t *testing.T) { + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + e := echo.New() + e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "ok") }, + RateLimiterWithConfig(RateLimiterConfig{ + Store: store, + IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil }, + })) + + do := func() *httptest.ResponseRecorder { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec + } + + // Burst of 3: each allowed request advertises the limit and decreasing remaining. + for i := 0; i < 3; i++ { + rec := do() + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit)) + assert.Equal(t, strconv.Itoa(2-i), rec.Header().Get(HeaderXRateLimitRemaining)) + assert.Empty(t, rec.Header().Get(echo.HeaderRetryAfter)) + } + + // 4th request is denied: 429, remaining 0, and a Retry-After hint. + rec := do() + assert.Equal(t, http.StatusTooManyRequests, rec.Code) + assert.Equal(t, "0", rec.Header().Get(HeaderXRateLimitRemaining)) + assert.NotEmpty(t, rec.Header().Get(echo.HeaderRetryAfter)) +}