Skip to content

Commit c9477eb

Browse files
vishrclaude
andauthored
feat(middleware): optional RateLimiterStoreContext for response headers (#2961) (#3007)
* feat(middleware): optional RateLimiterStoreContext for response headers (#2961) Adds an optional RateLimiterStoreContext interface. When the configured store implements AllowContext(c, identifier), the rate limiter calls it instead of Allow, giving the store access to the request context so it can set response headers such as Retry-After / X-RateLimit-*. Fully backward compatible: stores implementing only Allow are unchanged. This is the optional-interface approach proposed by the maintainer in the issue thread; it does not alter the existing Allow interface or the built-in store. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * feat(middleware): set X-RateLimit-* / Retry-After from built-in store (#2961) Implements AllowContext on RateLimiterMemoryStore so the default store sets X-RateLimit-Limit, X-RateLimit-Remaining, and (on deny) Retry-After headers out of the box — mirroring the v4 PR #2985 by @leno23 on the v5 line. Allow() is refactored to share an internal allow() with AllowContext; the optional RateLimiterStoreContext interface (added earlier in this PR) routes the middleware to AllowContext when the store implements it. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent b1d65e4 commit c9477eb

2 files changed

Lines changed: 147 additions & 4 deletions

File tree

middleware/rate_limiter.go

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,33 @@ import (
77
"errors"
88
"math"
99
"net/http"
10+
"strconv"
1011
"sync"
1112
"time"
1213

1314
"github.com/labstack/echo/v5"
1415
"golang.org/x/time/rate"
1516
)
1617

18+
// Rate limit response headers set by stores that implement RateLimiterStoreContext.
19+
const (
20+
HeaderXRateLimitLimit = "X-RateLimit-Limit"
21+
HeaderXRateLimitRemaining = "X-RateLimit-Remaining"
22+
)
23+
1724
// RateLimiterStore is the interface to be implemented by custom stores.
1825
type RateLimiterStore interface {
1926
Allow(identifier string) (bool, error)
2027
}
2128

29+
// RateLimiterStoreContext is an optional interface a RateLimiterStore may implement.
30+
// When the configured store implements it, the rate limiter calls AllowContext
31+
// (with the request context) instead of Allow, allowing the store to set response
32+
// headers such as Retry-After or X-RateLimit-* on the allow/deny decision.
33+
type RateLimiterStoreContext interface {
34+
AllowContext(c *echo.Context, identifier string) (bool, error)
35+
}
36+
2237
// RateLimiterConfig defines the configuration for the rate limiter
2338
type RateLimiterConfig struct {
2439
Skipper Skipper
@@ -136,7 +151,14 @@ func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
136151
return config.ErrorHandler(c, err)
137152
}
138153

139-
if allow, allowErr := config.Store.Allow(identifier); !allow {
154+
var allow bool
155+
var allowErr error
156+
if sc, ok := config.Store.(RateLimiterStoreContext); ok {
157+
allow, allowErr = sc.AllowContext(c, identifier)
158+
} else {
159+
allow, allowErr = config.Store.Allow(identifier)
160+
}
161+
if !allow {
140162
return config.DenyHandler(c, identifier, allowErr)
141163
}
142164
return next(c)
@@ -232,7 +254,22 @@ var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
232254

233255
// Allow implements RateLimiterStore.Allow
234256
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
257+
_, allowed := store.allow(identifier)
258+
return allowed, nil
259+
}
260+
261+
// AllowContext implements RateLimiterStoreContext: it makes the allow/deny decision
262+
// and sets the X-RateLimit-* (and Retry-After when denied) response headers.
263+
func (store *RateLimiterMemoryStore) AllowContext(c *echo.Context, identifier string) (bool, error) {
264+
limiter, allowed := store.allow(identifier)
265+
store.setRateLimitHeaders(c, limiter, allowed)
266+
return allowed, nil
267+
}
268+
269+
func (store *RateLimiterMemoryStore) allow(identifier string) (*rate.Limiter, bool) {
235270
store.mutex.Lock()
271+
defer store.mutex.Unlock()
272+
236273
limiter, exists := store.visitors[identifier]
237274
if !exists {
238275
limiter = new(Visitor)
@@ -244,9 +281,26 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
244281
if now.Sub(store.lastCleanup) > store.expiresIn {
245282
store.cleanupStaleVisitors(now)
246283
}
247-
allowed := limiter.AllowN(now, 1)
248-
store.mutex.Unlock()
249-
return allowed, nil
284+
return limiter.Limiter, limiter.AllowN(now, 1)
285+
}
286+
287+
func (store *RateLimiterMemoryStore) setRateLimitHeaders(c *echo.Context, limiter *rate.Limiter, allowed bool) {
288+
header := c.Response().Header()
289+
header.Set(HeaderXRateLimitLimit, strconv.Itoa(store.burst))
290+
291+
remaining := int(math.Floor(limiter.Tokens()))
292+
if remaining < 0 {
293+
remaining = 0
294+
}
295+
header.Set(HeaderXRateLimitRemaining, strconv.Itoa(remaining))
296+
297+
if !allowed {
298+
reservation := limiter.ReserveN(store.timeNow(), 1)
299+
if delay := reservation.Delay(); delay > 0 {
300+
header.Set(echo.HeaderRetryAfter, strconv.Itoa(int(math.Ceil(delay.Seconds()))))
301+
}
302+
reservation.Cancel()
303+
}
250304
}
251305

252306
/*
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// SPDX-License-Identifier: MIT
2+
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3+
4+
package middleware
5+
6+
import (
7+
"net/http"
8+
"net/http/httptest"
9+
"strconv"
10+
"testing"
11+
12+
"github.com/labstack/echo/v5"
13+
"github.com/stretchr/testify/assert"
14+
)
15+
16+
// ctxAwareStore implements both Allow and the optional AllowContext. AllowContext
17+
// gives the store the request context so it can set response headers (e.g.
18+
// Retry-After / X-RateLimit-*) — see #2961.
19+
type ctxAwareStore struct {
20+
allowCalled bool
21+
ctxAllowCalled bool
22+
allow bool
23+
}
24+
25+
func (s *ctxAwareStore) Allow(identifier string) (bool, error) {
26+
s.allowCalled = true
27+
return s.allow, nil
28+
}
29+
30+
func (s *ctxAwareStore) AllowContext(c *echo.Context, identifier string) (bool, error) {
31+
s.ctxAllowCalled = true
32+
c.Response().Header().Set("Retry-After", "42")
33+
return s.allow, nil
34+
}
35+
36+
// When the store implements AllowContext, the middleware must call it instead of
37+
// Allow, so the store can set rate-limit headers on the response.
38+
func TestRateLimiter_storeAllowContextIsPreferred(t *testing.T) {
39+
e := echo.New()
40+
store := &ctxAwareStore{allow: true}
41+
mw := RateLimiterWithConfig(RateLimiterConfig{
42+
Store: store,
43+
IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil },
44+
})
45+
handler := mw(func(c *echo.Context) error { return c.String(http.StatusOK, "ok") })
46+
47+
req := httptest.NewRequest(http.MethodGet, "/", nil)
48+
rec := httptest.NewRecorder()
49+
c := e.NewContext(req, rec)
50+
51+
assert.NoError(t, handler(c))
52+
assert.True(t, store.ctxAllowCalled, "AllowContext should be called when implemented")
53+
assert.False(t, store.allowCalled, "Allow should not be called when AllowContext is implemented")
54+
assert.Equal(t, "42", rec.Header().Get("Retry-After"), "store should be able to set headers via the context")
55+
}
56+
57+
// The built-in memory store implements AllowContext, so it sets X-RateLimit-Limit /
58+
// X-RateLimit-Remaining on every request and Retry-After when the limit is hit (#2961).
59+
func TestRateLimiterMemoryStore_AllowContextSetsHeaders(t *testing.T) {
60+
store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
61+
e := echo.New()
62+
e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "ok") },
63+
RateLimiterWithConfig(RateLimiterConfig{
64+
Store: store,
65+
IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil },
66+
}))
67+
68+
do := func() *httptest.ResponseRecorder {
69+
req := httptest.NewRequest(http.MethodGet, "/", nil)
70+
rec := httptest.NewRecorder()
71+
e.ServeHTTP(rec, req)
72+
return rec
73+
}
74+
75+
// Burst of 3: each allowed request advertises the limit and decreasing remaining.
76+
for i := 0; i < 3; i++ {
77+
rec := do()
78+
assert.Equal(t, http.StatusOK, rec.Code)
79+
assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit))
80+
assert.Equal(t, strconv.Itoa(2-i), rec.Header().Get(HeaderXRateLimitRemaining))
81+
assert.Empty(t, rec.Header().Get(echo.HeaderRetryAfter))
82+
}
83+
84+
// 4th request is denied: 429, remaining 0, and a Retry-After hint.
85+
rec := do()
86+
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
87+
assert.Equal(t, "0", rec.Header().Get(HeaderXRateLimitRemaining))
88+
assert.NotEmpty(t, rec.Header().Get(echo.HeaderRetryAfter))
89+
}

0 commit comments

Comments
 (0)