@@ -52,6 +52,17 @@ type rateLimitMiddleware struct {
5252 client redis.UniversalClient
5353}
5454
55+ // ToolNameResolver resolves the rate-limit tool name from a parsed MCP request.
56+ type ToolNameResolver func (* mcp.ParsedMCPRequest ) string
57+
58+ // DefaultToolNameResolver uses the parsed MCP resource ID as the rate-limit tool name.
59+ func DefaultToolNameResolver (parsed * mcp.ParsedMCPRequest ) string {
60+ if parsed == nil {
61+ return ""
62+ }
63+ return parsed .ResourceID
64+ }
65+
5566// Handler returns the middleware function used by the proxy.
5667func (m * rateLimitMiddleware ) Handler () types.MiddlewareFunction {
5768 return m .handler
@@ -99,16 +110,19 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
99110 }
100111
101112 mw := & rateLimitMiddleware {
102- handler : rateLimitHandler (limiter ),
113+ handler : NewMiddleware (limiter , nil ),
103114 client : client ,
104115 }
105116 runner .AddMiddleware (MiddlewareType , mw )
106117 return nil
107118}
108119
109- // rateLimitHandler returns a middleware function that enforces rate limits
120+ // NewMiddleware returns a middleware function that enforces rate limits
110121// on tools/call requests.
111- func rateLimitHandler (limiter Limiter ) types.MiddlewareFunction {
122+ func NewMiddleware (limiter Limiter , resolveToolName ToolNameResolver ) types.MiddlewareFunction {
123+ if resolveToolName == nil {
124+ resolveToolName = DefaultToolNameResolver
125+ }
112126 return func (next http.Handler ) http.Handler {
113127 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
114128 // Rate limits only apply to parsed tools/call requests.
@@ -127,7 +141,7 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
127141 if identity , ok := auth .IdentityFromContext (r .Context ()); ok {
128142 userID = identity .Subject
129143 }
130- decision , err := limiter .Allow (r .Context (), parsed . ResourceID , userID )
144+ decision , err := limiter .Allow (r .Context (), resolveToolName ( parsed ) , userID )
131145 if err != nil {
132146 slog .Warn ("rate limit check failed, allowing request" , "error" , err )
133147 next .ServeHTTP (w , r )
@@ -142,6 +156,11 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
142156 }
143157}
144158
159+ // rateLimitHandler returns the default rate-limit middleware used by tests and legacy callers.
160+ func rateLimitHandler (limiter Limiter ) types.MiddlewareFunction {
161+ return NewMiddleware (limiter , nil )
162+ }
163+
145164// writeRateLimited writes an HTTP 429 response with a JSON-RPC error body.
146165func writeRateLimited (w http.ResponseWriter , requestID any , retryAfter time.Duration ) {
147166 retrySeconds := int (math .Ceil (retryAfter .Seconds ()))
0 commit comments