@@ -6,25 +6,17 @@ package factory
66
77import (
88 "context"
9- "encoding/json"
109 "fmt"
11- "log/slog"
12- "math"
1310 "net/http"
14- "os"
15- "time"
1611
17- "github.com/redis/go-redis/v9"
18-
19- "github.com/stacklok/toolhive/pkg/auth"
20- mcpparser "github.com/stacklok/toolhive/pkg/mcp"
12+ "github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
13+ "github.com/stacklok/toolhive/pkg/authserver/server/keys"
2114 "github.com/stacklok/toolhive/pkg/ratelimit"
2215 ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
16+ transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
2317 vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
2418)
2519
26- const redisPingTimeout = 5 * time .Second
27-
2820// Config contains the vMCP rate-limit middleware inputs.
2921type Config struct {
3022 Namespace string
@@ -35,7 +27,7 @@ type Config struct {
3527
3628// NewMiddleware creates Redis-backed rate-limit middleware for vMCP.
3729func NewMiddleware (
38- ctx context.Context ,
30+ _ context.Context ,
3931 cfg Config ,
4032) (func (http.Handler ) http.Handler , func (context.Context ) error , error ) {
4133 if cfg .RateLimiting == nil {
@@ -48,88 +40,51 @@ func NewMiddleware(
4840 return nil , nil , fmt .Errorf ("rate limiting requires Redis session storage address" )
4941 }
5042
51- client := redis .NewClient (& redis.Options {
52- Addr : cfg .SessionStorage .Address ,
53- DB : int (cfg .SessionStorage .DB ),
54- Password : os .Getenv (vmcpconfig .RedisPasswordEnvVar ),
43+ middlewareConfig , err := transporttypes .NewMiddlewareConfig (ratelimit .MiddlewareType , ratelimit.MiddlewareParams {
44+ Namespace : cfg .Namespace ,
45+ ServerName : cfg .ServerName ,
46+ Config : cfg .RateLimiting ,
47+ RedisAddr : cfg .SessionStorage .Address ,
48+ RedisDB : cfg .SessionStorage .DB ,
5549 })
56-
57- pingCtx , cancel := context .WithTimeout (ctx , redisPingTimeout )
58- defer cancel ()
59- if err := client .Ping (pingCtx ).Err (); err != nil {
60- _ = client .Close ()
61- return nil , nil , fmt .Errorf ("rate limit middleware: failed to connect to Redis at %s: %w" ,
62- cfg .SessionStorage .Address , err )
50+ if err != nil {
51+ return nil , nil , fmt .Errorf ("failed to create rate limit middleware config: %w" , err )
6352 }
6453
65- limiter , err := ratelimit .NewLimiter (client , cfg .Namespace , cfg .ServerName , cfg .RateLimiting )
66- if err != nil {
67- _ = client .Close ()
68- return nil , nil , fmt .Errorf ("failed to create rate limiter: %w" , err )
54+ runner := & captureRunner {}
55+ if err := ratelimit .CreateMiddleware (middlewareConfig , runner ); err != nil {
56+ return nil , nil , err
57+ }
58+ if runner .middleware == nil {
59+ return nil , nil , fmt .Errorf ("rate limit middleware factory did not register middleware" )
6960 }
7061
7162 cleanup := func (context.Context ) error {
72- return client .Close ()
63+ return runner . middleware .Close ()
7364 }
74- return rateLimitHandler ( limiter ), cleanup , nil
65+ return runner . middleware . Handler ( ), cleanup , nil
7566}
7667
77- func rateLimitHandler (limiter ratelimit.Limiter ) func (http.Handler ) http.Handler {
78- return func (next http.Handler ) http.Handler {
79- return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
80- parsed := mcpparser .GetParsedMCPRequest (r .Context ())
81- if parsed == nil || parsed .Method != "tools/call" {
82- next .ServeHTTP (w , r )
83- return
84- }
68+ type captureRunner struct {
69+ middleware transporttypes.Middleware
70+ }
8571
86- var userID string
87- if identity , ok := auth .IdentityFromContext (r .Context ()); ok {
88- userID = identity .Subject
89- }
90- decision , err := limiter .Allow (r .Context (), parsed .ResourceID , userID )
91- if err != nil {
92- slog .Warn ("rate limit check failed, allowing request" , "error" , err )
93- next .ServeHTTP (w , r )
94- return
95- }
96- if ! decision .Allowed {
97- writeRateLimited (w , parsed .ID , decision .RetryAfter )
98- return
99- }
100- next .ServeHTTP (w , r )
101- })
102- }
72+ func (r * captureRunner ) AddMiddleware (_ string , middleware transporttypes.Middleware ) {
73+ r .middleware = middleware
10374}
10475
105- func writeRateLimited (w http.ResponseWriter , requestID any , retryAfter time.Duration ) {
106- retrySeconds := int (math .Ceil (retryAfter .Seconds ()))
107- w .Header ().Set ("Content-Type" , "application/json" )
108- w .Header ().Set ("Retry-After" , fmt .Sprintf ("%d" , retrySeconds ))
109- w .WriteHeader (http .StatusTooManyRequests )
110- //nolint:gosec // G104: writing a static JSON error response to an HTTP client
111- _ , _ = w .Write (rateLimitedBody (requestID , retryAfter ))
76+ func (* captureRunner ) SetAuthInfoHandler (http.Handler ) {}
77+
78+ func (* captureRunner ) SetPrometheusHandler (http.Handler ) {}
79+
80+ func (* captureRunner ) GetConfig () transporttypes.RunnerConfig {
81+ return nil
11282}
11383
114- func rateLimitedBody (requestID any , retryAfter time.Duration ) []byte {
115- retrySeconds := math .Ceil (retryAfter .Seconds ())
116- resp := map [string ]any {
117- "jsonrpc" : "2.0" ,
118- "error" : map [string ]any {
119- "code" : ratelimit .CodeRateLimited ,
120- "message" : ratelimit .MessageRateLimited ,
121- "data" : map [string ]any {
122- "retryAfterSeconds" : retrySeconds ,
123- },
124- },
125- "id" : requestID ,
126- }
127- data , err := json .Marshal (resp )
128- if err != nil {
129- return []byte (fmt .Sprintf (
130- `{"jsonrpc":"2.0","error":{"code":-32029,"message":"Rate limit exceeded","data":{"retryAfterSeconds":%.0f}},"id":null}` ,
131- retrySeconds ,
132- ))
133- }
134- return data
84+ func (* captureRunner ) GetUpstreamTokenReader () upstreamtoken.TokenReader {
85+ return nil
86+ }
87+
88+ func (* captureRunner ) GetKeyProvider () keys.PublicKeyProvider {
89+ return nil
13590}
0 commit comments