Skip to content

Commit 5b90ba9

Browse files
committed
vMCP rate-limit middleware wiring
Signed-off-by: Sanskarzz <sanskar.gur@gmail.com>
1 parent 08baf2d commit 5b90ba9

6 files changed

Lines changed: 728 additions & 79 deletions

File tree

pkg/ratelimit/middleware.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ type rateLimitMiddleware struct {
5252
client redis.UniversalClient
5353
}
5454

55+
// ToolNameResolver resolves the rate-limit tool name from a parsed MCP request.
56+
type ToolNameResolver func(*mcp.ParsedMCPRequest) string
57+
58+
// DefaultToolNameResolver uses the parsed MCP resource ID as the rate-limit tool name.
59+
func DefaultToolNameResolver(parsed *mcp.ParsedMCPRequest) string {
60+
if parsed == nil {
61+
return ""
62+
}
63+
return parsed.ResourceID
64+
}
65+
5566
// Handler returns the middleware function used by the proxy.
5667
func (m *rateLimitMiddleware) Handler() types.MiddlewareFunction {
5768
return m.handler
@@ -99,16 +110,19 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
99110
}
100111

101112
mw := &rateLimitMiddleware{
102-
handler: rateLimitHandler(limiter),
113+
handler: NewMiddleware(limiter, nil),
103114
client: client,
104115
}
105116
runner.AddMiddleware(MiddlewareType, mw)
106117
return nil
107118
}
108119

109-
// rateLimitHandler returns a middleware function that enforces rate limits
120+
// NewMiddleware returns a middleware function that enforces rate limits
110121
// on tools/call requests.
111-
func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
122+
func NewMiddleware(limiter Limiter, resolveToolName ToolNameResolver) types.MiddlewareFunction {
123+
if resolveToolName == nil {
124+
resolveToolName = DefaultToolNameResolver
125+
}
112126
return func(next http.Handler) http.Handler {
113127
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
114128
// Rate limits only apply to parsed tools/call requests.
@@ -127,7 +141,7 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
127141
if identity, ok := auth.IdentityFromContext(r.Context()); ok {
128142
userID = identity.Subject
129143
}
130-
decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID)
144+
decision, err := limiter.Allow(r.Context(), resolveToolName(parsed), userID)
131145
if err != nil {
132146
slog.Warn("rate limit check failed, allowing request", "error", err)
133147
next.ServeHTTP(w, r)
@@ -142,6 +156,11 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
142156
}
143157
}
144158

159+
// rateLimitHandler returns the default rate-limit middleware used by tests and legacy callers.
160+
func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
161+
return NewMiddleware(limiter, nil)
162+
}
163+
145164
// writeRateLimited writes an HTTP 429 response with a JSON-RPC error body.
146165
func writeRateLimited(w http.ResponseWriter, requestID any, retryAfter time.Duration) {
147166
retrySeconds := int(math.Ceil(retryAfter.Seconds()))

pkg/vmcp/cli/serve.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
376376

377377
serverCfg := &vmcpserver.Config{
378378
Name: vmcpCfg.Name,
379+
Namespace: vmcpNamespace(),
379380
Version: versions.Version,
380381
GroupRef: vmcpCfg.Group,
381382
Host: cfg.Host,
@@ -394,6 +395,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
394395
OptimizerConfig: optCfg,
395396
SessionFactory: sessionFactory,
396397
SessionStorage: vmcpCfg.SessionStorage,
398+
RateLimiting: vmcpCfg.RateLimiting,
397399
}
398400

399401
// Assign Watcher only when backendWatcher is non-nil. A typed nil
@@ -529,6 +531,14 @@ func generateQuickModeConfig(groupRef string) (*config.Config, error) {
529531
return cfg, nil
530532
}
531533

534+
func vmcpNamespace() string {
535+
namespace := os.Getenv("VMCP_NAMESPACE")
536+
if namespace == "" {
537+
return "local"
538+
}
539+
return namespace
540+
}
541+
532542
// loadAuthServerConfig loads the auth server RunConfig from a sibling file
533543
// alongside the main config. The operator serializes authserver.RunConfig as a
534544
// separate ConfigMap key (authserver-config.yaml).

pkg/vmcp/server/ratelimit.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package server
5+
6+
import (
7+
"context"
8+
"fmt"
9+
"net/http"
10+
"os"
11+
"time"
12+
13+
"github.com/redis/go-redis/v9"
14+
15+
mcpparser "github.com/stacklok/toolhive/pkg/mcp"
16+
"github.com/stacklok/toolhive/pkg/ratelimit"
17+
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
18+
)
19+
20+
const rateLimitRedisPingTimeout = 5 * time.Second
21+
22+
func (s *Server) buildRateLimitMiddleware(
23+
ctx context.Context,
24+
) (func(http.Handler) http.Handler, func(context.Context) error, error) {
25+
if s.config.RateLimiting == nil {
26+
return nil, nil, nil
27+
}
28+
if s.config.SessionStorage == nil || s.config.SessionStorage.Provider != "redis" {
29+
return nil, nil, fmt.Errorf("rate limiting requires Redis session storage")
30+
}
31+
if s.config.SessionStorage.Address == "" {
32+
return nil, nil, fmt.Errorf("rate limiting requires Redis session storage address")
33+
}
34+
35+
client := redis.NewClient(&redis.Options{
36+
Addr: s.config.SessionStorage.Address,
37+
DB: int(s.config.SessionStorage.DB),
38+
Password: os.Getenv(vmcpconfig.RedisPasswordEnvVar),
39+
})
40+
41+
pingCtx, cancel := context.WithTimeout(ctx, rateLimitRedisPingTimeout)
42+
defer cancel()
43+
if err := client.Ping(pingCtx).Err(); err != nil {
44+
_ = client.Close()
45+
return nil, nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w",
46+
s.config.SessionStorage.Address, err)
47+
}
48+
49+
limiter, err := ratelimit.NewLimiter(client, s.config.Namespace, s.config.Name, s.config.RateLimiting)
50+
if err != nil {
51+
_ = client.Close()
52+
return nil, nil, fmt.Errorf("failed to create rate limiter: %w", err)
53+
}
54+
55+
cleanup := func(context.Context) error {
56+
return client.Close()
57+
}
58+
return ratelimit.NewMiddleware(limiter, s.rateLimitToolName), cleanup, nil
59+
}
60+
61+
func (s *Server) rateLimitToolName(parsed *mcpparser.ParsedMCPRequest) string {
62+
if parsed == nil {
63+
return ""
64+
}
65+
toolName := parsed.ResourceID
66+
if !s.optimizerEnabled() || toolName != "call_tool" {
67+
return toolName
68+
}
69+
if parsed.Arguments == nil {
70+
return toolName
71+
}
72+
innerToolName, ok := parsed.Arguments["tool_name"].(string)
73+
if !ok || innerToolName == "" {
74+
return toolName
75+
}
76+
return innerToolName
77+
}
78+
79+
func (s *Server) optimizerEnabled() bool {
80+
return s.config.OptimizerConfig != nil || s.config.OptimizerFactory != nil
81+
}

0 commit comments

Comments
 (0)