Skip to content

Commit 24cc607

Browse files
committed
feat: enable MCP parameter headers and add validation tests using internal JSON unmarshaling
1 parent 005d33d commit 24cc607

2 files changed

Lines changed: 287 additions & 3 deletions

File tree

mcp/streamable_headers.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ const (
2727
mcpHeaderExtension = "x-mcp-header"
2828
)
2929

30+
// ---------------------------------------------------------------------------
31+
// Shared helpers (used by both client and server)
32+
// ---------------------------------------------------------------------------
33+
3034
func extractName(method string, params json.RawMessage) (string, bool) {
3135
switch method {
3236
case "tools/call":
@@ -107,7 +111,7 @@ func primitiveToString(value any) string {
107111
// (string, float64, or bool). Returns nil for non-primitive types.
108112
func unmarshalPrimitive(raw json.RawMessage) any {
109113
var val any
110-
if err := json.Unmarshal(raw, &val); err != nil {
114+
if err := internaljson.Unmarshal(raw, &val); err != nil {
111115
return nil
112116
}
113117
switch val.(type) {
@@ -118,6 +122,10 @@ func unmarshalPrimitive(raw json.RawMessage) any {
118122
}
119123
}
120124

125+
// ---------------------------------------------------------------------------
126+
// Client-side helpers
127+
// ---------------------------------------------------------------------------
128+
121129
// setStandardHeaders populates standard MCP headers.
122130
// It requires the protocol version header to be set.
123131
func setStandardHeaders(header http.Header, msg jsonrpc.Message) {
@@ -153,7 +161,7 @@ func setParamHeaders(header http.Header, tool *Tool, params json.RawMessage) {
153161
var raw struct {
154162
Arguments map[string]json.RawMessage `json:"arguments"`
155163
}
156-
if err := json.Unmarshal(params, &raw); err != nil || raw.Arguments == nil {
164+
if err := internaljson.Unmarshal(params, &raw); err != nil || raw.Arguments == nil {
157165
return
158166
}
159167

@@ -271,6 +279,10 @@ func validateHeaderName(name string) error {
271279
return nil
272280
}
273281

282+
// ---------------------------------------------------------------------------
283+
// Server-side helpers
284+
// ---------------------------------------------------------------------------
285+
274286
func validateMcpHeaders(header http.Header, msg jsonrpc.Message, tool *Tool) error {
275287
protocolVersion := header.Get(protocolVersionHeader)
276288
if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders {
@@ -319,7 +331,7 @@ func validateParamHeaders(header http.Header, msg *jsonrpc.Request, tool *Tool)
319331
var raw struct {
320332
Arguments map[string]json.RawMessage `json:"arguments"`
321333
}
322-
if err := json.Unmarshal(msg.Params, &raw); err != nil {
334+
if err := internaljson.Unmarshal(msg.Params, &raw); err != nil {
323335
return nil
324336
}
325337

mcp/streamable_test.go

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2170,6 +2170,278 @@ func TestStreamableMcpHeaderVersionGating(t *testing.T) {
21702170
})
21712171
}
21722172

2173+
// TestStreamableParamHeadersClientSetsHeaders verifies that the client sets
2174+
// Mcp-Param-* headers on tool calls when the tool has x-mcp-header annotations.
2175+
func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) {
2176+
orig := supportedProtocolVersions
2177+
supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders)
2178+
t.Cleanup(func() { supportedProtocolVersions = orig })
2179+
2180+
server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil)
2181+
server.AddTool(
2182+
&Tool{
2183+
Name: "execute_sql",
2184+
InputSchema: map[string]any{
2185+
"type": "object",
2186+
"properties": map[string]any{
2187+
"region": map[string]any{
2188+
"type": "string",
2189+
"x-mcp-header": "Region",
2190+
},
2191+
"query": map[string]any{
2192+
"type": "string",
2193+
},
2194+
},
2195+
},
2196+
},
2197+
func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
2198+
return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil
2199+
})
2200+
2201+
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
2202+
defer handler.closeAll()
2203+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
2204+
defer httpServer.Close()
2205+
2206+
var capturedHeaders http.Header
2207+
customClient := &http.Client{
2208+
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
2209+
if req.Header.Get(methodHeader) == "tools/call" {
2210+
capturedHeaders = req.Header.Clone()
2211+
}
2212+
return http.DefaultTransport.RoundTrip(req)
2213+
}),
2214+
}
2215+
2216+
clientTransport := &StreamableClientTransport{
2217+
Endpoint: httpServer.URL,
2218+
HTTPClient: customClient,
2219+
}
2220+
2221+
client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil)
2222+
ctx := context.Background()
2223+
session, err := client.Connect(ctx, clientTransport, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders})
2224+
if err != nil {
2225+
t.Fatal(err)
2226+
}
2227+
defer session.Close()
2228+
2229+
// ListTools to populate the tool cache (needed for param headers).
2230+
if _, err := session.ListTools(ctx, nil); err != nil {
2231+
t.Fatal(err)
2232+
}
2233+
2234+
_, err = session.CallTool(ctx, &CallToolParams{
2235+
Name: "execute_sql",
2236+
Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"},
2237+
})
2238+
if err != nil {
2239+
t.Fatal(err)
2240+
}
2241+
2242+
if capturedHeaders == nil {
2243+
t.Fatal("no tool call headers captured")
2244+
}
2245+
if got := capturedHeaders.Get(methodHeader); got != "tools/call" {
2246+
t.Errorf("Mcp-Method = %q, want %q", got, "tools/call")
2247+
}
2248+
if got := capturedHeaders.Get(nameHeader); got != "execute_sql" {
2249+
t.Errorf("Mcp-Name = %q, want %q", got, "execute_sql")
2250+
}
2251+
if got := capturedHeaders.Get(paramHeaderPrefix + "Region"); got != "us-west1" {
2252+
t.Errorf("Mcp-Param-Region = %q, want %q", got, "us-west1")
2253+
}
2254+
if got := capturedHeaders.Get("Mcp-Param-query"); got != "" {
2255+
t.Errorf("non-annotated param got header: Mcp-Param-query = %q", got)
2256+
}
2257+
}
2258+
2259+
// TestStreamableParamHeadersServerValidation verifies that the server
2260+
// validates Mcp-Param-* headers against the body for tools with
2261+
// x-mcp-header annotations.
2262+
func TestStreamableParamHeadersServerValidation(t *testing.T) {
2263+
orig := supportedProtocolVersions
2264+
supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders)
2265+
t.Cleanup(func() { supportedProtocolVersions = orig })
2266+
2267+
server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil)
2268+
server.AddTool(
2269+
&Tool{
2270+
Name: "execute_sql",
2271+
InputSchema: map[string]any{
2272+
"type": "object",
2273+
"properties": map[string]any{
2274+
"region": map[string]any{
2275+
"type": "string",
2276+
"x-mcp-header": "Region",
2277+
},
2278+
"query": map[string]any{
2279+
"type": "string",
2280+
},
2281+
},
2282+
},
2283+
},
2284+
func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
2285+
return &CallToolResult{}, nil
2286+
})
2287+
2288+
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
2289+
defer handler.closeAll()
2290+
2291+
initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: minVersionForStandardHeaders})
2292+
initResp := resp(1, &InitializeResult{
2293+
Capabilities: &ServerCapabilities{
2294+
Logging: &LoggingCapabilities{},
2295+
Tools: &ToolCapabilities{ListChanged: true},
2296+
},
2297+
ProtocolVersion: minVersionForStandardHeaders,
2298+
ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"},
2299+
}, nil)
2300+
2301+
testStreamableHandler(t, handler, []streamableRequest{
2302+
{
2303+
method: "POST",
2304+
messages: []jsonrpc.Message{initReq},
2305+
wantStatusCode: http.StatusOK,
2306+
wantMessages: []jsonrpc.Message{initResp},
2307+
wantSessionID: true,
2308+
},
2309+
{
2310+
method: "POST",
2311+
headers: http.Header{
2312+
protocolVersionHeader: {minVersionForStandardHeaders},
2313+
methodHeader: {notificationInitialized},
2314+
},
2315+
messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})},
2316+
wantStatusCode: http.StatusAccepted,
2317+
},
2318+
// Correct param header should succeed.
2319+
{
2320+
method: "POST",
2321+
headers: http.Header{
2322+
protocolVersionHeader: {minVersionForStandardHeaders},
2323+
methodHeader: {"tools/call"},
2324+
nameHeader: {"execute_sql"},
2325+
paramHeaderPrefix + "Region": {"us-west1"},
2326+
},
2327+
messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{
2328+
Name: "execute_sql",
2329+
Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"},
2330+
})},
2331+
wantStatusCode: http.StatusOK,
2332+
wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)},
2333+
},
2334+
// Mismatched param header value should fail.
2335+
{
2336+
method: "POST",
2337+
headers: http.Header{
2338+
protocolVersionHeader: {minVersionForStandardHeaders},
2339+
methodHeader: {"tools/call"},
2340+
nameHeader: {"execute_sql"},
2341+
paramHeaderPrefix + "Region": {"eu-central1"},
2342+
},
2343+
messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{
2344+
Name: "execute_sql",
2345+
Arguments: map[string]any{"region": "us-west1"},
2346+
})},
2347+
wantStatusCode: http.StatusBadRequest,
2348+
wantBodyContaining: "header mismatch",
2349+
},
2350+
// Missing param header when body has the argument should fail.
2351+
{
2352+
method: "POST",
2353+
headers: http.Header{
2354+
protocolVersionHeader: {minVersionForStandardHeaders},
2355+
methodHeader: {"tools/call"},
2356+
nameHeader: {"execute_sql"},
2357+
},
2358+
messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{
2359+
Name: "execute_sql",
2360+
Arguments: map[string]any{"region": "us-west1"},
2361+
})},
2362+
wantStatusCode: http.StatusBadRequest,
2363+
wantBodyContaining: "missing",
2364+
},
2365+
})
2366+
}
2367+
2368+
// TestStreamableFilterValidToolsIntegration verifies that invalid tools
2369+
// (with bad x-mcp-header annotations) are filtered out when the client
2370+
// calls ListTools.
2371+
func TestStreamableFilterValidToolsIntegration(t *testing.T) {
2372+
orig := supportedProtocolVersions
2373+
supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders)
2374+
t.Cleanup(func() { supportedProtocolVersions = orig })
2375+
2376+
server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil)
2377+
noop := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
2378+
return &CallToolResult{}, nil
2379+
}
2380+
2381+
// Valid tool with correct x-mcp-header annotation.
2382+
server.AddTool(&Tool{
2383+
Name: "valid-tool",
2384+
InputSchema: map[string]any{
2385+
"type": "object",
2386+
"properties": map[string]any{
2387+
"region": map[string]any{
2388+
"type": "string",
2389+
"x-mcp-header": "Region",
2390+
},
2391+
},
2392+
},
2393+
}, noop)
2394+
2395+
// Invalid tool: x-mcp-header on an array type.
2396+
server.AddTool(&Tool{
2397+
Name: "invalid-tool",
2398+
InputSchema: map[string]any{
2399+
"type": "object",
2400+
"properties": map[string]any{
2401+
"items": map[string]any{
2402+
"type": "array",
2403+
"x-mcp-header": "Items",
2404+
},
2405+
},
2406+
},
2407+
}, noop)
2408+
2409+
// Tool with no x-mcp-header annotations (always valid).
2410+
server.AddTool(&Tool{
2411+
Name: "plain-tool",
2412+
InputSchema: &jsonschema.Schema{Type: "object"},
2413+
}, noop)
2414+
2415+
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
2416+
defer handler.closeAll()
2417+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
2418+
defer httpServer.Close()
2419+
2420+
client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil)
2421+
ctx := context.Background()
2422+
session, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders})
2423+
if err != nil {
2424+
t.Fatal(err)
2425+
}
2426+
defer session.Close()
2427+
2428+
result, err := session.ListTools(ctx, nil)
2429+
if err != nil {
2430+
t.Fatal(err)
2431+
}
2432+
2433+
toolNames := make([]string, len(result.Tools))
2434+
for i, tool := range result.Tools {
2435+
toolNames[i] = tool.Name
2436+
}
2437+
sort.Strings(toolNames)
2438+
2439+
wantNames := []string{"plain-tool", "valid-tool"}
2440+
if !slices.Equal(toolNames, wantNames) {
2441+
t.Errorf("ListTools returned %v, want %v", toolNames, wantNames)
2442+
}
2443+
}
2444+
21732445
// TestStreamable405AllowHeader verifies RFC 9110 §15.5.6 compliance:
21742446
// 405 Method Not Allowed responses MUST include an Allow header.
21752447
func TestStreamable405AllowHeader(t *testing.T) {

0 commit comments

Comments
 (0)