@@ -2170,6 +2170,278 @@ func TestStreamableMcpHeaderVersionGating(t *testing.T) {
21702170 })
21712171}
21722172
2173+ // TestStreamableParamHeadersClientSetsHeaders verifies that the client sets
2174+ // Mcp-Param-* headers on tool calls when the tool has x-mcp-header annotations.
2175+ func TestStreamableParamHeadersClientSetsHeaders (t * testing.T ) {
2176+ orig := supportedProtocolVersions
2177+ supportedProtocolVersions = append (slices .Clone (orig ), minVersionForStandardHeaders )
2178+ t .Cleanup (func () { supportedProtocolVersions = orig })
2179+
2180+ server := NewServer (& Implementation {Name : "testServer" , Version : "v1.0.0" }, nil )
2181+ server .AddTool (
2182+ & Tool {
2183+ Name : "execute_sql" ,
2184+ InputSchema : map [string ]any {
2185+ "type" : "object" ,
2186+ "properties" : map [string ]any {
2187+ "region" : map [string ]any {
2188+ "type" : "string" ,
2189+ "x-mcp-header" : "Region" ,
2190+ },
2191+ "query" : map [string ]any {
2192+ "type" : "string" ,
2193+ },
2194+ },
2195+ },
2196+ },
2197+ func (ctx context.Context , req * CallToolRequest ) (* CallToolResult , error ) {
2198+ return & CallToolResult {Content : []Content {& TextContent {Text : "ok" }}}, nil
2199+ })
2200+
2201+ handler := NewStreamableHTTPHandler (func (req * http.Request ) * Server { return server }, nil )
2202+ defer handler .closeAll ()
2203+ httpServer := httptest .NewServer (mustNotPanic (t , handler ))
2204+ defer httpServer .Close ()
2205+
2206+ var capturedHeaders http.Header
2207+ customClient := & http.Client {
2208+ Transport : roundTripperFunc (func (req * http.Request ) (* http.Response , error ) {
2209+ if req .Header .Get (methodHeader ) == "tools/call" {
2210+ capturedHeaders = req .Header .Clone ()
2211+ }
2212+ return http .DefaultTransport .RoundTrip (req )
2213+ }),
2214+ }
2215+
2216+ clientTransport := & StreamableClientTransport {
2217+ Endpoint : httpServer .URL ,
2218+ HTTPClient : customClient ,
2219+ }
2220+
2221+ client := NewClient (& Implementation {Name : "testClient" , Version : "v1.0.0" }, nil )
2222+ ctx := context .Background ()
2223+ session , err := client .Connect (ctx , clientTransport , & ClientSessionOptions {protocolVersion : minVersionForStandardHeaders })
2224+ if err != nil {
2225+ t .Fatal (err )
2226+ }
2227+ defer session .Close ()
2228+
2229+ // ListTools to populate the tool cache (needed for param headers).
2230+ if _ , err := session .ListTools (ctx , nil ); err != nil {
2231+ t .Fatal (err )
2232+ }
2233+
2234+ _ , err = session .CallTool (ctx , & CallToolParams {
2235+ Name : "execute_sql" ,
2236+ Arguments : map [string ]any {"region" : "us-west1" , "query" : "SELECT 1" },
2237+ })
2238+ if err != nil {
2239+ t .Fatal (err )
2240+ }
2241+
2242+ if capturedHeaders == nil {
2243+ t .Fatal ("no tool call headers captured" )
2244+ }
2245+ if got := capturedHeaders .Get (methodHeader ); got != "tools/call" {
2246+ t .Errorf ("Mcp-Method = %q, want %q" , got , "tools/call" )
2247+ }
2248+ if got := capturedHeaders .Get (nameHeader ); got != "execute_sql" {
2249+ t .Errorf ("Mcp-Name = %q, want %q" , got , "execute_sql" )
2250+ }
2251+ if got := capturedHeaders .Get (paramHeaderPrefix + "Region" ); got != "us-west1" {
2252+ t .Errorf ("Mcp-Param-Region = %q, want %q" , got , "us-west1" )
2253+ }
2254+ if got := capturedHeaders .Get ("Mcp-Param-query" ); got != "" {
2255+ t .Errorf ("non-annotated param got header: Mcp-Param-query = %q" , got )
2256+ }
2257+ }
2258+
2259+ // TestStreamableParamHeadersServerValidation verifies that the server
2260+ // validates Mcp-Param-* headers against the body for tools with
2261+ // x-mcp-header annotations.
2262+ func TestStreamableParamHeadersServerValidation (t * testing.T ) {
2263+ orig := supportedProtocolVersions
2264+ supportedProtocolVersions = append (slices .Clone (orig ), minVersionForStandardHeaders )
2265+ t .Cleanup (func () { supportedProtocolVersions = orig })
2266+
2267+ server := NewServer (& Implementation {Name : "testServer" , Version : "v1.0.0" }, nil )
2268+ server .AddTool (
2269+ & Tool {
2270+ Name : "execute_sql" ,
2271+ InputSchema : map [string ]any {
2272+ "type" : "object" ,
2273+ "properties" : map [string ]any {
2274+ "region" : map [string ]any {
2275+ "type" : "string" ,
2276+ "x-mcp-header" : "Region" ,
2277+ },
2278+ "query" : map [string ]any {
2279+ "type" : "string" ,
2280+ },
2281+ },
2282+ },
2283+ },
2284+ func (ctx context.Context , req * CallToolRequest ) (* CallToolResult , error ) {
2285+ return & CallToolResult {}, nil
2286+ })
2287+
2288+ handler := NewStreamableHTTPHandler (func (req * http.Request ) * Server { return server }, nil )
2289+ defer handler .closeAll ()
2290+
2291+ initReq := req (1 , methodInitialize , & InitializeParams {ProtocolVersion : minVersionForStandardHeaders })
2292+ initResp := resp (1 , & InitializeResult {
2293+ Capabilities : & ServerCapabilities {
2294+ Logging : & LoggingCapabilities {},
2295+ Tools : & ToolCapabilities {ListChanged : true },
2296+ },
2297+ ProtocolVersion : minVersionForStandardHeaders ,
2298+ ServerInfo : & Implementation {Name : "testServer" , Version : "v1.0.0" },
2299+ }, nil )
2300+
2301+ testStreamableHandler (t , handler , []streamableRequest {
2302+ {
2303+ method : "POST" ,
2304+ messages : []jsonrpc.Message {initReq },
2305+ wantStatusCode : http .StatusOK ,
2306+ wantMessages : []jsonrpc.Message {initResp },
2307+ wantSessionID : true ,
2308+ },
2309+ {
2310+ method : "POST" ,
2311+ headers : http.Header {
2312+ protocolVersionHeader : {minVersionForStandardHeaders },
2313+ methodHeader : {notificationInitialized },
2314+ },
2315+ messages : []jsonrpc.Message {req (0 , notificationInitialized , & InitializedParams {})},
2316+ wantStatusCode : http .StatusAccepted ,
2317+ },
2318+ // Correct param header should succeed.
2319+ {
2320+ method : "POST" ,
2321+ headers : http.Header {
2322+ protocolVersionHeader : {minVersionForStandardHeaders },
2323+ methodHeader : {"tools/call" },
2324+ nameHeader : {"execute_sql" },
2325+ paramHeaderPrefix + "Region" : {"us-west1" },
2326+ },
2327+ messages : []jsonrpc.Message {req (2 , "tools/call" , & CallToolParams {
2328+ Name : "execute_sql" ,
2329+ Arguments : map [string ]any {"region" : "us-west1" , "query" : "SELECT 1" },
2330+ })},
2331+ wantStatusCode : http .StatusOK ,
2332+ wantMessages : []jsonrpc.Message {resp (2 , & CallToolResult {Content : []Content {}}, nil )},
2333+ },
2334+ // Mismatched param header value should fail.
2335+ {
2336+ method : "POST" ,
2337+ headers : http.Header {
2338+ protocolVersionHeader : {minVersionForStandardHeaders },
2339+ methodHeader : {"tools/call" },
2340+ nameHeader : {"execute_sql" },
2341+ paramHeaderPrefix + "Region" : {"eu-central1" },
2342+ },
2343+ messages : []jsonrpc.Message {req (3 , "tools/call" , & CallToolParams {
2344+ Name : "execute_sql" ,
2345+ Arguments : map [string ]any {"region" : "us-west1" },
2346+ })},
2347+ wantStatusCode : http .StatusBadRequest ,
2348+ wantBodyContaining : "header mismatch" ,
2349+ },
2350+ // Missing param header when body has the argument should fail.
2351+ {
2352+ method : "POST" ,
2353+ headers : http.Header {
2354+ protocolVersionHeader : {minVersionForStandardHeaders },
2355+ methodHeader : {"tools/call" },
2356+ nameHeader : {"execute_sql" },
2357+ },
2358+ messages : []jsonrpc.Message {req (4 , "tools/call" , & CallToolParams {
2359+ Name : "execute_sql" ,
2360+ Arguments : map [string ]any {"region" : "us-west1" },
2361+ })},
2362+ wantStatusCode : http .StatusBadRequest ,
2363+ wantBodyContaining : "missing" ,
2364+ },
2365+ })
2366+ }
2367+
2368+ // TestStreamableFilterValidToolsIntegration verifies that invalid tools
2369+ // (with bad x-mcp-header annotations) are filtered out when the client
2370+ // calls ListTools.
2371+ func TestStreamableFilterValidToolsIntegration (t * testing.T ) {
2372+ orig := supportedProtocolVersions
2373+ supportedProtocolVersions = append (slices .Clone (orig ), minVersionForStandardHeaders )
2374+ t .Cleanup (func () { supportedProtocolVersions = orig })
2375+
2376+ server := NewServer (& Implementation {Name : "testServer" , Version : "v1.0.0" }, nil )
2377+ noop := func (ctx context.Context , req * CallToolRequest ) (* CallToolResult , error ) {
2378+ return & CallToolResult {}, nil
2379+ }
2380+
2381+ // Valid tool with correct x-mcp-header annotation.
2382+ server .AddTool (& Tool {
2383+ Name : "valid-tool" ,
2384+ InputSchema : map [string ]any {
2385+ "type" : "object" ,
2386+ "properties" : map [string ]any {
2387+ "region" : map [string ]any {
2388+ "type" : "string" ,
2389+ "x-mcp-header" : "Region" ,
2390+ },
2391+ },
2392+ },
2393+ }, noop )
2394+
2395+ // Invalid tool: x-mcp-header on an array type.
2396+ server .AddTool (& Tool {
2397+ Name : "invalid-tool" ,
2398+ InputSchema : map [string ]any {
2399+ "type" : "object" ,
2400+ "properties" : map [string ]any {
2401+ "items" : map [string ]any {
2402+ "type" : "array" ,
2403+ "x-mcp-header" : "Items" ,
2404+ },
2405+ },
2406+ },
2407+ }, noop )
2408+
2409+ // Tool with no x-mcp-header annotations (always valid).
2410+ server .AddTool (& Tool {
2411+ Name : "plain-tool" ,
2412+ InputSchema : & jsonschema.Schema {Type : "object" },
2413+ }, noop )
2414+
2415+ handler := NewStreamableHTTPHandler (func (req * http.Request ) * Server { return server }, nil )
2416+ defer handler .closeAll ()
2417+ httpServer := httptest .NewServer (mustNotPanic (t , handler ))
2418+ defer httpServer .Close ()
2419+
2420+ client := NewClient (& Implementation {Name : "testClient" , Version : "v1.0.0" }, nil )
2421+ ctx := context .Background ()
2422+ session , err := client .Connect (ctx , & StreamableClientTransport {Endpoint : httpServer .URL }, & ClientSessionOptions {protocolVersion : minVersionForStandardHeaders })
2423+ if err != nil {
2424+ t .Fatal (err )
2425+ }
2426+ defer session .Close ()
2427+
2428+ result , err := session .ListTools (ctx , nil )
2429+ if err != nil {
2430+ t .Fatal (err )
2431+ }
2432+
2433+ toolNames := make ([]string , len (result .Tools ))
2434+ for i , tool := range result .Tools {
2435+ toolNames [i ] = tool .Name
2436+ }
2437+ sort .Strings (toolNames )
2438+
2439+ wantNames := []string {"plain-tool" , "valid-tool" }
2440+ if ! slices .Equal (toolNames , wantNames ) {
2441+ t .Errorf ("ListTools returned %v, want %v" , toolNames , wantNames )
2442+ }
2443+ }
2444+
21732445// TestStreamable405AllowHeader verifies RFC 9110 §15.5.6 compliance:
21742446// 405 Method Not Allowed responses MUST include an Allow header.
21752447func TestStreamable405AllowHeader (t * testing.T ) {
0 commit comments