Skip to content

Commit f5a7973

Browse files
committed
removed duplication
Signed-off-by: Sanskarzz <sanskar.gur@gmail.com>
1 parent f870b62 commit f5a7973

2 files changed

Lines changed: 37 additions & 113 deletions

File tree

pkg/vmcp/ratelimit/factory/middleware.go

Lines changed: 37 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,17 @@ package factory
66

77
import (
88
"context"
9-
"encoding/json"
109
"fmt"
11-
"log/slog"
12-
"math"
1310
"net/http"
14-
"os"
15-
"time"
1611

17-
"github.com/redis/go-redis/v9"
18-
19-
"github.com/stacklok/toolhive/pkg/auth"
20-
mcpparser "github.com/stacklok/toolhive/pkg/mcp"
12+
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
13+
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
2114
"github.com/stacklok/toolhive/pkg/ratelimit"
2215
ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
16+
transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
2317
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
2418
)
2519

26-
const redisPingTimeout = 5 * time.Second
27-
2820
// Config contains the vMCP rate-limit middleware inputs.
2921
type Config struct {
3022
Namespace string
@@ -35,7 +27,7 @@ type Config struct {
3527

3628
// NewMiddleware creates Redis-backed rate-limit middleware for vMCP.
3729
func NewMiddleware(
38-
ctx context.Context,
30+
_ context.Context,
3931
cfg Config,
4032
) (func(http.Handler) http.Handler, func(context.Context) error, error) {
4133
if cfg.RateLimiting == nil {
@@ -48,88 +40,51 @@ func NewMiddleware(
4840
return nil, nil, fmt.Errorf("rate limiting requires Redis session storage address")
4941
}
5042

51-
client := redis.NewClient(&redis.Options{
52-
Addr: cfg.SessionStorage.Address,
53-
DB: int(cfg.SessionStorage.DB),
54-
Password: os.Getenv(vmcpconfig.RedisPasswordEnvVar),
43+
middlewareConfig, err := transporttypes.NewMiddlewareConfig(ratelimit.MiddlewareType, ratelimit.MiddlewareParams{
44+
Namespace: cfg.Namespace,
45+
ServerName: cfg.ServerName,
46+
Config: cfg.RateLimiting,
47+
RedisAddr: cfg.SessionStorage.Address,
48+
RedisDB: cfg.SessionStorage.DB,
5549
})
56-
57-
pingCtx, cancel := context.WithTimeout(ctx, redisPingTimeout)
58-
defer cancel()
59-
if err := client.Ping(pingCtx).Err(); err != nil {
60-
_ = client.Close()
61-
return nil, nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w",
62-
cfg.SessionStorage.Address, err)
50+
if err != nil {
51+
return nil, nil, fmt.Errorf("failed to create rate limit middleware config: %w", err)
6352
}
6453

65-
limiter, err := ratelimit.NewLimiter(client, cfg.Namespace, cfg.ServerName, cfg.RateLimiting)
66-
if err != nil {
67-
_ = client.Close()
68-
return nil, nil, fmt.Errorf("failed to create rate limiter: %w", err)
54+
runner := &captureRunner{}
55+
if err := ratelimit.CreateMiddleware(middlewareConfig, runner); err != nil {
56+
return nil, nil, err
57+
}
58+
if runner.middleware == nil {
59+
return nil, nil, fmt.Errorf("rate limit middleware factory did not register middleware")
6960
}
7061

7162
cleanup := func(context.Context) error {
72-
return client.Close()
63+
return runner.middleware.Close()
7364
}
74-
return rateLimitHandler(limiter), cleanup, nil
65+
return runner.middleware.Handler(), cleanup, nil
7566
}
7667

77-
func rateLimitHandler(limiter ratelimit.Limiter) func(http.Handler) http.Handler {
78-
return func(next http.Handler) http.Handler {
79-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80-
parsed := mcpparser.GetParsedMCPRequest(r.Context())
81-
if parsed == nil || parsed.Method != "tools/call" {
82-
next.ServeHTTP(w, r)
83-
return
84-
}
68+
type captureRunner struct {
69+
middleware transporttypes.Middleware
70+
}
8571

86-
var userID string
87-
if identity, ok := auth.IdentityFromContext(r.Context()); ok {
88-
userID = identity.Subject
89-
}
90-
decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID)
91-
if err != nil {
92-
slog.Warn("rate limit check failed, allowing request", "error", err)
93-
next.ServeHTTP(w, r)
94-
return
95-
}
96-
if !decision.Allowed {
97-
writeRateLimited(w, parsed.ID, decision.RetryAfter)
98-
return
99-
}
100-
next.ServeHTTP(w, r)
101-
})
102-
}
72+
func (r *captureRunner) AddMiddleware(_ string, middleware transporttypes.Middleware) {
73+
r.middleware = middleware
10374
}
10475

105-
func writeRateLimited(w http.ResponseWriter, requestID any, retryAfter time.Duration) {
106-
retrySeconds := int(math.Ceil(retryAfter.Seconds()))
107-
w.Header().Set("Content-Type", "application/json")
108-
w.Header().Set("Retry-After", fmt.Sprintf("%d", retrySeconds))
109-
w.WriteHeader(http.StatusTooManyRequests)
110-
//nolint:gosec // G104: writing a static JSON error response to an HTTP client
111-
_, _ = w.Write(rateLimitedBody(requestID, retryAfter))
76+
func (*captureRunner) SetAuthInfoHandler(http.Handler) {}
77+
78+
func (*captureRunner) SetPrometheusHandler(http.Handler) {}
79+
80+
func (*captureRunner) GetConfig() transporttypes.RunnerConfig {
81+
return nil
11282
}
11383

114-
func rateLimitedBody(requestID any, retryAfter time.Duration) []byte {
115-
retrySeconds := math.Ceil(retryAfter.Seconds())
116-
resp := map[string]any{
117-
"jsonrpc": "2.0",
118-
"error": map[string]any{
119-
"code": ratelimit.CodeRateLimited,
120-
"message": ratelimit.MessageRateLimited,
121-
"data": map[string]any{
122-
"retryAfterSeconds": retrySeconds,
123-
},
124-
},
125-
"id": requestID,
126-
}
127-
data, err := json.Marshal(resp)
128-
if err != nil {
129-
return []byte(fmt.Sprintf(
130-
`{"jsonrpc":"2.0","error":{"code":-32029,"message":"Rate limit exceeded","data":{"retryAfterSeconds":%.0f}},"id":null}`,
131-
retrySeconds,
132-
))
133-
}
134-
return data
84+
func (*captureRunner) GetUpstreamTokenReader() upstreamtoken.TokenReader {
85+
return nil
86+
}
87+
88+
func (*captureRunner) GetKeyProvider() keys.PublicKeyProvider {
89+
return nil
13590
}

pkg/vmcp/ratelimit/factory/middleware_test.go

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,6 @@ import (
2323
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
2424
)
2525

26-
type recordingLimiter struct {
27-
toolName string
28-
userID string
29-
}
30-
31-
func (r *recordingLimiter) Allow(_ context.Context, toolName, userID string) (*ratelimit.Decision, error) {
32-
r.toolName = toolName
33-
r.userID = userID
34-
return &ratelimit.Decision{Allowed: true}, nil
35-
}
36-
3726
func TestNewMiddlewareDisabledWithoutConfig(t *testing.T) {
3827
t.Parallel()
3928

@@ -169,26 +158,6 @@ func TestRateLimitMiddlewareUsesPostAggregationToolNames(t *testing.T) {
169158
assert.Equal(t, http.StatusTooManyRequests, secondMatchingTool.Code)
170159
}
171160

172-
func TestRateLimitHandlerPassesParsedResourceIDAndUserID(t *testing.T) {
173-
t.Parallel()
174-
175-
recorder := &recordingLimiter{}
176-
handler := rateLimitHandler(recorder)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
177-
w.WriteHeader(http.StatusOK)
178-
}))
179-
180-
req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
181-
req = withParsedMCPRequest(req, "tools/call", "backend_a_echo", 1)
182-
req = withIdentity(req, "alice@example.com")
183-
w := httptest.NewRecorder()
184-
185-
handler.ServeHTTP(w, req)
186-
187-
assert.Equal(t, http.StatusOK, w.Code)
188-
assert.Equal(t, "backend_a_echo", recorder.toolName)
189-
assert.Equal(t, "alice@example.com", recorder.userID)
190-
}
191-
192161
func newTestRateLimitHandler(t *testing.T, cfg *ratelimittypes.RateLimitConfig) http.Handler {
193162
t.Helper()
194163

0 commit comments

Comments
 (0)