@@ -15,13 +15,18 @@ import (
1515 "github.com/alicebob/miniredis/v2"
1616 "github.com/stretchr/testify/assert"
1717 "github.com/stretchr/testify/require"
18+ "go.uber.org/mock/gomock"
1819 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1920
2021 "github.com/stacklok/toolhive/pkg/auth"
2122 mcpparser "github.com/stacklok/toolhive/pkg/mcp"
2223 ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
24+ "github.com/stacklok/toolhive/pkg/vmcp"
2325 vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
26+ discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks"
27+ "github.com/stacklok/toolhive/pkg/vmcp/mocks"
2428 "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
29+ routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks"
2530)
2631
2732func TestBuildRateLimitMiddlewareDisabledWithoutConfig (t * testing.T ) {
@@ -55,6 +60,82 @@ func TestBuildRateLimitMiddlewareRequiresRedisSessionStorage(t *testing.T) {
5560 assert .Nil (t , cleanup )
5661}
5762
63+ func TestBuildRateLimitMiddlewareRequiresRedisAddress (t * testing.T ) {
64+ t .Parallel ()
65+
66+ s := & Server {
67+ config : & Config {
68+ Name : "vmcp" ,
69+ Namespace : "default" ,
70+ RateLimiting : sharedRateLimitConfig (1 ),
71+ SessionStorage : & vmcpconfig.SessionStorageConfig {
72+ Provider : "redis" ,
73+ },
74+ },
75+ }
76+
77+ middleware , cleanup , err := s .buildRateLimitMiddleware (t .Context ())
78+
79+ require .Error (t , err )
80+ assert .Contains (t , err .Error (), "requires Redis session storage address" )
81+ assert .Nil (t , middleware )
82+ assert .Nil (t , cleanup )
83+ }
84+
85+ func TestBuildRateLimitMiddlewareRedisPingFailure (t * testing.T ) {
86+ t .Parallel ()
87+
88+ ctx , cancel := context .WithTimeout (t .Context (), 100 * time .Millisecond )
89+ defer cancel ()
90+ s := & Server {
91+ config : & Config {
92+ Name : "vmcp" ,
93+ Namespace : "default" ,
94+ RateLimiting : sharedRateLimitConfig (1 ),
95+ SessionStorage : & vmcpconfig.SessionStorageConfig {
96+ Provider : "redis" ,
97+ Address : "127.0.0.1:1" ,
98+ },
99+ },
100+ }
101+
102+ middleware , cleanup , err := s .buildRateLimitMiddleware (ctx )
103+
104+ require .Error (t , err )
105+ assert .Contains (t , err .Error (), "failed to connect to Redis" )
106+ assert .Nil (t , middleware )
107+ assert .Nil (t , cleanup )
108+ }
109+
110+ func TestBuildRateLimitMiddlewareInvalidRateLimitConfig (t * testing.T ) {
111+ t .Parallel ()
112+
113+ mr := miniredis .RunT (t )
114+ s := & Server {
115+ config : & Config {
116+ Name : "vmcp" ,
117+ Namespace : "default" ,
118+ RateLimiting : & ratelimittypes.RateLimitConfig {
119+ Shared : & ratelimittypes.RateLimitBucket {
120+ MaxTokens : 0 ,
121+ RefillPeriod : metav1.Duration {Duration : time .Minute },
122+ },
123+ },
124+ SessionStorage : & vmcpconfig.SessionStorageConfig {
125+ Provider : "redis" ,
126+ Address : mr .Addr (),
127+ },
128+ },
129+ }
130+
131+ middleware , cleanup , err := s .buildRateLimitMiddleware (t .Context ())
132+
133+ require .Error (t , err )
134+ assert .Contains (t , err .Error (), "failed to create rate limiter" )
135+ assert .Nil (t , middleware )
136+ assert .Nil (t , cleanup )
137+ }
138+
58139func TestRateLimitMiddlewarePerUserSharedAcrossTools (t * testing.T ) {
59140 t .Parallel ()
60141
@@ -150,6 +231,96 @@ func TestRateLimitToolNameFallsBackToCallTool(t *testing.T) {
150231 assert .Equal (t , "call_tool" , s .rateLimitToolName (parsed ))
151232}
152233
234+ func TestRateLimitToolNameNilParsedRequest (t * testing.T ) {
235+ t .Parallel ()
236+
237+ s := & Server {config : & Config {}}
238+
239+ assert .Empty (t , s .rateLimitToolName (nil ))
240+ }
241+
242+ func TestRateLimitToolNameOptimizerFallsBackForInvalidInnerToolName (t * testing.T ) {
243+ t .Parallel ()
244+
245+ s := & Server {config : & Config {OptimizerConfig : & optimizer.Config {}}}
246+ parsed := & mcpparser.ParsedMCPRequest {
247+ Method : "tools/call" ,
248+ ResourceID : "call_tool" ,
249+ Arguments : map [string ]any {"tool_name" : 123 },
250+ }
251+
252+ assert .Equal (t , "call_tool" , s .rateLimitToolName (parsed ))
253+ }
254+
255+ func TestApplyRateLimitingWrapsConfiguredMiddleware (t * testing.T ) {
256+ t .Parallel ()
257+
258+ s := & Server {
259+ rateLimitMiddleware : func (next http.Handler ) http.Handler {
260+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
261+ w .Header ().Set ("X-Rate-Limit-Test" , "wrapped" )
262+ next .ServeHTTP (w , r )
263+ })
264+ },
265+ }
266+ handler := s .applyRateLimiting (http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
267+ w .WriteHeader (http .StatusAccepted )
268+ }))
269+
270+ rec := httptest .NewRecorder ()
271+ req := httptest .NewRequest (http .MethodPost , "/mcp" , nil )
272+ handler .ServeHTTP (rec , req )
273+
274+ assert .Equal (t , http .StatusAccepted , rec .Code )
275+ assert .Equal (t , "wrapped" , rec .Header ().Get ("X-Rate-Limit-Test" ))
276+ }
277+
278+ func TestApplyAuthorizationWrapsConfiguredMiddleware (t * testing.T ) {
279+ t .Parallel ()
280+
281+ s := & Server {config : & Config {
282+ AuthzMiddleware : func (next http.Handler ) http.Handler {
283+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
284+ w .Header ().Set ("X-Authz-Test" , "wrapped" )
285+ next .ServeHTTP (w , r )
286+ })
287+ },
288+ }}
289+ handler := s .applyAuthorization (http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
290+ w .WriteHeader (http .StatusAccepted )
291+ }))
292+
293+ rec := httptest .NewRecorder ()
294+ req := httptest .NewRequest (http .MethodPost , "/mcp" , nil )
295+ handler .ServeHTTP (rec , req )
296+
297+ assert .Equal (t , http .StatusAccepted , rec .Code )
298+ assert .Equal (t , "wrapped" , rec .Header ().Get ("X-Authz-Test" ))
299+ }
300+
301+ func TestNewDefaultsNamespace (t * testing.T ) {
302+ t .Parallel ()
303+
304+ ctrl := gomock .NewController (t )
305+ t .Cleanup (ctrl .Finish )
306+ mockRouter := routerMocks .NewMockRouter (ctrl )
307+ mockBackendClient := mocks .NewMockBackendClient (ctrl )
308+ mockDiscoveryMgr := discoveryMocks .NewMockManager (ctrl )
309+
310+ s , err := New (
311+ t .Context (),
312+ & Config {SessionFactory : testMinimalFactory ()},
313+ mockRouter ,
314+ mockBackendClient ,
315+ mockDiscoveryMgr ,
316+ vmcp .NewImmutableRegistry ([]vmcp.Backend {}),
317+ nil ,
318+ )
319+
320+ require .NoError (t , err )
321+ assert .Equal (t , "local" , s .config .Namespace )
322+ }
323+
153324func newTestRateLimitHandler (t * testing.T , cfg * Config ) http.Handler {
154325 t .Helper ()
155326
0 commit comments