Skip to content

Commit f3cf1e4

Browse files
committed
improve test coverage
Signed-off-by: Sanskarzz <sanskar.gur@gmail.com>
1 parent 691c37e commit f3cf1e4

3 files changed

Lines changed: 258 additions & 0 deletions

File tree

pkg/ratelimit/middleware_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,17 @@ import (
1313
"testing"
1414
"time"
1515

16+
"github.com/alicebob/miniredis/v2"
1617
"github.com/stretchr/testify/assert"
1718
"github.com/stretchr/testify/require"
19+
"go.uber.org/mock/gomock"
20+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1821

22+
v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
1923
"github.com/stacklok/toolhive/pkg/auth"
2024
"github.com/stacklok/toolhive/pkg/mcp"
25+
transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
26+
transportmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks"
2127
)
2228

2329
// dummyLimiter is a test double for the Limiter interface.
@@ -208,3 +214,70 @@ func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) {
208214
assert.Equal(t, "echo", recorder.toolName)
209215
assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID")
210216
}
217+
218+
func TestDefaultToolNameResolverNilParsedRequest(t *testing.T) {
219+
t.Parallel()
220+
221+
assert.Empty(t, DefaultToolNameResolver(nil))
222+
}
223+
224+
func TestNewMiddlewareUsesCustomToolNameResolver(t *testing.T) {
225+
t.Parallel()
226+
227+
recorder := &recordingLimiter{}
228+
handler := NewMiddleware(recorder, func(*mcp.ParsedMCPRequest) string {
229+
return "resolved-tool"
230+
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
231+
w.WriteHeader(http.StatusOK)
232+
}))
233+
234+
req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
235+
req = withParsedMCPRequest(req, "tools/call", "raw-tool", 1)
236+
w := httptest.NewRecorder()
237+
238+
handler.ServeHTTP(w, req)
239+
240+
assert.Equal(t, http.StatusOK, w.Code)
241+
assert.Equal(t, "resolved-tool", recorder.toolName)
242+
}
243+
244+
func TestRateLimitMiddlewareHandlerReturnsConfiguredHandler(t *testing.T) {
245+
t.Parallel()
246+
247+
expected := rateLimitHandler(&dummyLimiter{decision: &Decision{Allowed: true}})
248+
mw := &rateLimitMiddleware{handler: expected}
249+
250+
assert.NotNil(t, mw.Handler())
251+
}
252+
253+
func TestCreateMiddlewareRegistersUsableMiddleware(t *testing.T) {
254+
t.Parallel()
255+
256+
mr := miniredis.RunT(t)
257+
cfg, err := transporttypes.NewMiddlewareConfig(MiddlewareType, MiddlewareParams{
258+
Namespace: "default",
259+
ServerName: "server",
260+
RedisAddr: mr.Addr(),
261+
Config: &v1beta1.RateLimitConfig{
262+
Shared: &v1beta1.RateLimitBucket{
263+
MaxTokens: 1,
264+
RefillPeriod: metav1.Duration{Duration: time.Minute},
265+
},
266+
},
267+
})
268+
require.NoError(t, err)
269+
270+
ctrl := gomock.NewController(t)
271+
runner := transportmocks.NewMockMiddlewareRunner(ctrl)
272+
var registered transporttypes.Middleware
273+
runner.EXPECT().
274+
AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&rateLimitMiddleware{})).
275+
Do(func(_ string, middleware transporttypes.Middleware) {
276+
registered = middleware
277+
})
278+
279+
require.NoError(t, CreateMiddleware(cfg, runner))
280+
require.NotNil(t, registered)
281+
require.NotNil(t, registered.Handler())
282+
require.NoError(t, registered.Close())
283+
}

pkg/vmcp/cli/serve_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,20 @@ func TestValidateQuickModeHost(t *testing.T) {
337337
}
338338
}
339339

340+
func TestVMCPNamespace(t *testing.T) {
341+
t.Run("defaults to local", func(t *testing.T) {
342+
t.Setenv("VMCP_NAMESPACE", "")
343+
344+
assert.Equal(t, "local", vmcpNamespace())
345+
})
346+
347+
t.Run("uses environment value", func(t *testing.T) {
348+
t.Setenv("VMCP_NAMESPACE", "toolhive-system")
349+
350+
assert.Equal(t, "toolhive-system", vmcpNamespace())
351+
})
352+
}
353+
340354
// TestRunDiscovery_ZeroBackends exercises the branch in runDiscovery where the
341355
// discoverer succeeds but returns no backends. The function must return a
342356
// non-error, an empty (non-nil) backend slice, and pass through the client and

pkg/vmcp/server/ratelimit_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,18 @@ import (
1515
"github.com/alicebob/miniredis/v2"
1616
"github.com/stretchr/testify/assert"
1717
"github.com/stretchr/testify/require"
18+
"go.uber.org/mock/gomock"
1819
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1920

2021
"github.com/stacklok/toolhive/pkg/auth"
2122
mcpparser "github.com/stacklok/toolhive/pkg/mcp"
2223
ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
24+
"github.com/stacklok/toolhive/pkg/vmcp"
2325
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
26+
discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks"
27+
"github.com/stacklok/toolhive/pkg/vmcp/mocks"
2428
"github.com/stacklok/toolhive/pkg/vmcp/optimizer"
29+
routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks"
2530
)
2631

2732
func TestBuildRateLimitMiddlewareDisabledWithoutConfig(t *testing.T) {
@@ -55,6 +60,82 @@ func TestBuildRateLimitMiddlewareRequiresRedisSessionStorage(t *testing.T) {
5560
assert.Nil(t, cleanup)
5661
}
5762

63+
func TestBuildRateLimitMiddlewareRequiresRedisAddress(t *testing.T) {
64+
t.Parallel()
65+
66+
s := &Server{
67+
config: &Config{
68+
Name: "vmcp",
69+
Namespace: "default",
70+
RateLimiting: sharedRateLimitConfig(1),
71+
SessionStorage: &vmcpconfig.SessionStorageConfig{
72+
Provider: "redis",
73+
},
74+
},
75+
}
76+
77+
middleware, cleanup, err := s.buildRateLimitMiddleware(t.Context())
78+
79+
require.Error(t, err)
80+
assert.Contains(t, err.Error(), "requires Redis session storage address")
81+
assert.Nil(t, middleware)
82+
assert.Nil(t, cleanup)
83+
}
84+
85+
func TestBuildRateLimitMiddlewareRedisPingFailure(t *testing.T) {
86+
t.Parallel()
87+
88+
ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
89+
defer cancel()
90+
s := &Server{
91+
config: &Config{
92+
Name: "vmcp",
93+
Namespace: "default",
94+
RateLimiting: sharedRateLimitConfig(1),
95+
SessionStorage: &vmcpconfig.SessionStorageConfig{
96+
Provider: "redis",
97+
Address: "127.0.0.1:1",
98+
},
99+
},
100+
}
101+
102+
middleware, cleanup, err := s.buildRateLimitMiddleware(ctx)
103+
104+
require.Error(t, err)
105+
assert.Contains(t, err.Error(), "failed to connect to Redis")
106+
assert.Nil(t, middleware)
107+
assert.Nil(t, cleanup)
108+
}
109+
110+
func TestBuildRateLimitMiddlewareInvalidRateLimitConfig(t *testing.T) {
111+
t.Parallel()
112+
113+
mr := miniredis.RunT(t)
114+
s := &Server{
115+
config: &Config{
116+
Name: "vmcp",
117+
Namespace: "default",
118+
RateLimiting: &ratelimittypes.RateLimitConfig{
119+
Shared: &ratelimittypes.RateLimitBucket{
120+
MaxTokens: 0,
121+
RefillPeriod: metav1.Duration{Duration: time.Minute},
122+
},
123+
},
124+
SessionStorage: &vmcpconfig.SessionStorageConfig{
125+
Provider: "redis",
126+
Address: mr.Addr(),
127+
},
128+
},
129+
}
130+
131+
middleware, cleanup, err := s.buildRateLimitMiddleware(t.Context())
132+
133+
require.Error(t, err)
134+
assert.Contains(t, err.Error(), "failed to create rate limiter")
135+
assert.Nil(t, middleware)
136+
assert.Nil(t, cleanup)
137+
}
138+
58139
func TestRateLimitMiddlewarePerUserSharedAcrossTools(t *testing.T) {
59140
t.Parallel()
60141

@@ -150,6 +231,96 @@ func TestRateLimitToolNameFallsBackToCallTool(t *testing.T) {
150231
assert.Equal(t, "call_tool", s.rateLimitToolName(parsed))
151232
}
152233

234+
func TestRateLimitToolNameNilParsedRequest(t *testing.T) {
235+
t.Parallel()
236+
237+
s := &Server{config: &Config{}}
238+
239+
assert.Empty(t, s.rateLimitToolName(nil))
240+
}
241+
242+
func TestRateLimitToolNameOptimizerFallsBackForInvalidInnerToolName(t *testing.T) {
243+
t.Parallel()
244+
245+
s := &Server{config: &Config{OptimizerConfig: &optimizer.Config{}}}
246+
parsed := &mcpparser.ParsedMCPRequest{
247+
Method: "tools/call",
248+
ResourceID: "call_tool",
249+
Arguments: map[string]any{"tool_name": 123},
250+
}
251+
252+
assert.Equal(t, "call_tool", s.rateLimitToolName(parsed))
253+
}
254+
255+
func TestApplyRateLimitingWrapsConfiguredMiddleware(t *testing.T) {
256+
t.Parallel()
257+
258+
s := &Server{
259+
rateLimitMiddleware: func(next http.Handler) http.Handler {
260+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
261+
w.Header().Set("X-Rate-Limit-Test", "wrapped")
262+
next.ServeHTTP(w, r)
263+
})
264+
},
265+
}
266+
handler := s.applyRateLimiting(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
267+
w.WriteHeader(http.StatusAccepted)
268+
}))
269+
270+
rec := httptest.NewRecorder()
271+
req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
272+
handler.ServeHTTP(rec, req)
273+
274+
assert.Equal(t, http.StatusAccepted, rec.Code)
275+
assert.Equal(t, "wrapped", rec.Header().Get("X-Rate-Limit-Test"))
276+
}
277+
278+
func TestApplyAuthorizationWrapsConfiguredMiddleware(t *testing.T) {
279+
t.Parallel()
280+
281+
s := &Server{config: &Config{
282+
AuthzMiddleware: func(next http.Handler) http.Handler {
283+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
284+
w.Header().Set("X-Authz-Test", "wrapped")
285+
next.ServeHTTP(w, r)
286+
})
287+
},
288+
}}
289+
handler := s.applyAuthorization(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
290+
w.WriteHeader(http.StatusAccepted)
291+
}))
292+
293+
rec := httptest.NewRecorder()
294+
req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
295+
handler.ServeHTTP(rec, req)
296+
297+
assert.Equal(t, http.StatusAccepted, rec.Code)
298+
assert.Equal(t, "wrapped", rec.Header().Get("X-Authz-Test"))
299+
}
300+
301+
func TestNewDefaultsNamespace(t *testing.T) {
302+
t.Parallel()
303+
304+
ctrl := gomock.NewController(t)
305+
t.Cleanup(ctrl.Finish)
306+
mockRouter := routerMocks.NewMockRouter(ctrl)
307+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
308+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
309+
310+
s, err := New(
311+
t.Context(),
312+
&Config{SessionFactory: testMinimalFactory()},
313+
mockRouter,
314+
mockBackendClient,
315+
mockDiscoveryMgr,
316+
vmcp.NewImmutableRegistry([]vmcp.Backend{}),
317+
nil,
318+
)
319+
320+
require.NoError(t, err)
321+
assert.Equal(t, "local", s.config.Namespace)
322+
}
323+
153324
func newTestRateLimitHandler(t *testing.T, cfg *Config) http.Handler {
154325
t.Helper()
155326

0 commit comments

Comments
 (0)