Skip to content

Commit 909372d

Browse files
mcp: HTTP Header Standardization for x-mcp-header (#915)
## Description Implements [SEP-2243](https://modelcontextprotocol.io/seps/2243-http-standardization) (HTTP Header Standardization) for x-mcp-param custom header. Fixes #905 --------- Co-authored-by: Maciej Kisiel <mkisiel@google.com>
1 parent d10c315 commit 909372d

7 files changed

Lines changed: 1512 additions & 15 deletions

File tree

mcp/client.go

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ type ClientOptions struct {
160160
KeepAlive time.Duration
161161
}
162162

163+
// toolContextKeyType is the context key type for passing tool definitions
164+
// from CallTool to the transport layer.
165+
type toolContextKeyType struct{}
166+
167+
var toolContextKey = toolContextKeyType{}
168+
163169
// bind implements the binder[*ClientSession] interface, so that Clients can
164170
// be connected using [connect].
165171
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession {
@@ -318,6 +324,13 @@ type ClientSession struct {
318324
// Pending URL elicitations waiting for completion notifications.
319325
pendingElicitationsMu sync.Mutex
320326
pendingElicitations map[string]chan struct{}
327+
328+
// toolCacheMu guards toolCache.
329+
toolCacheMu sync.RWMutex
330+
// toolCache stores tool definitions keyed by name.
331+
// It is used to look up x-mcp-header annotations when
332+
// constructing Mcp-Param-* headers for tools/call requests.
333+
toolCache map[string]*Tool
321334
}
322335

323336
type clientSessionState struct {
@@ -363,6 +376,21 @@ func (cs *ClientSession) Wait() error {
363376
return cs.conn.Wait()
364377
}
365378

379+
func (cs *ClientSession) cacheTools(tools []*Tool) {
380+
cs.toolCacheMu.Lock()
381+
defer cs.toolCacheMu.Unlock()
382+
cs.toolCache = make(map[string]*Tool, len(tools))
383+
for _, tool := range tools {
384+
cs.toolCache[tool.Name] = tool
385+
}
386+
}
387+
388+
func (cs *ClientSession) getCachedTool(name string) *Tool {
389+
cs.toolCacheMu.RLock()
390+
defer cs.toolCacheMu.RUnlock()
391+
return cs.toolCache[name]
392+
}
393+
366394
// registerElicitationWaiter registers a waiter for an elicitation complete
367395
// notification with the given elicitation ID. It returns two functions: an await
368396
// function that waits for the notification or context cancellation, and a cleanup
@@ -981,7 +1009,13 @@ func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams)
9811009

9821010
// ListTools lists tools that are currently available on the server.
9831011
func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) {
984-
return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params)))
1012+
result, err := handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params)))
1013+
if err != nil {
1014+
return nil, err
1015+
}
1016+
result.Tools = filterValidTools(cs.client.opts.Logger, result.Tools)
1017+
cs.cacheTools(result.Tools)
1018+
return result, nil
9851019
}
9861020

9871021
// CallTool calls the tool with the given parameters.
@@ -995,6 +1029,9 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (
9951029
// Avoid sending nil over the wire.
9961030
params.Arguments = map[string]any{}
9971031
}
1032+
if tool := cs.getCachedTool(params.Name); tool != nil {
1033+
ctx = context.WithValue(ctx, toolContextKey, tool)
1034+
}
9981035
return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))
9991036
}
10001037

mcp/client_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,80 @@ func TestClientCapabilities(t *testing.T) {
440440
}
441441
}
442442

443+
func TestToolCache(t *testing.T) {
444+
tool1 := &Tool{Name: "tool1", Description: "first"}
445+
tool2 := &Tool{Name: "tool2", Description: "second"}
446+
tool1Updated := &Tool{Name: "tool1", Description: "updated"}
447+
448+
testCases := []struct {
449+
name string
450+
cacheBatches [][]*Tool
451+
lookup string
452+
want *Tool
453+
}{
454+
{
455+
name: "empty cache",
456+
lookup: "tool1",
457+
want: nil,
458+
},
459+
{
460+
name: "single tool found",
461+
cacheBatches: [][]*Tool{{tool1}},
462+
lookup: "tool1",
463+
want: tool1,
464+
},
465+
{
466+
name: "unknown tool",
467+
cacheBatches: [][]*Tool{{tool1}},
468+
lookup: "nonexistent",
469+
want: nil,
470+
},
471+
{
472+
name: "multiple tools single batch",
473+
cacheBatches: [][]*Tool{{tool1, tool2}},
474+
lookup: "tool2",
475+
want: tool2,
476+
},
477+
{
478+
name: "replace clears old entries",
479+
cacheBatches: [][]*Tool{{tool1}, {tool2}},
480+
lookup: "tool1",
481+
want: nil,
482+
},
483+
{
484+
name: "replace keeps new entries",
485+
cacheBatches: [][]*Tool{{tool1}, {tool2}},
486+
lookup: "tool2",
487+
want: tool2,
488+
},
489+
{
490+
name: "overwrite existing entry",
491+
cacheBatches: [][]*Tool{{tool1}, {tool1Updated}},
492+
lookup: "tool1",
493+
want: tool1Updated,
494+
},
495+
{
496+
name: "empty batch no-op",
497+
cacheBatches: [][]*Tool{{}},
498+
lookup: "tool1",
499+
want: nil,
500+
},
501+
}
502+
503+
for _, tc := range testCases {
504+
t.Run(tc.name, func(t *testing.T) {
505+
cs := &ClientSession{}
506+
for _, batch := range tc.cacheBatches {
507+
cs.cacheTools(batch)
508+
}
509+
got := cs.getCachedTool(tc.lookup)
510+
if diff := cmp.Diff(tc.want, got); diff != "" {
511+
t.Errorf("getCachedTool(%q) mismatch (-want +got):\n%s", tc.lookup, diff)
512+
}
513+
})
514+
}
515+
}
516+
443517
func TestClientCapabilitiesOverWire(t *testing.T) {
444518
testCases := []struct {
445519
name string

mcp/server.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
280280
}
281281
}
282282
}
283+
if err := validateParamHeaderAnnotations(t); err != nil {
284+
panic(fmt.Errorf("AddTool %q: invalid parameter header annotations: %v", t.Name, err))
285+
}
283286
st := &serverTool{tool: t, handler: h}
284287
// Assume there was a change, since add replaces existing tools.
285288
// (It's possible a tool was replaced with an identical one, but not worth checking.)
@@ -753,10 +756,15 @@ func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListTools
753756
})
754757
}
755758

756-
func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
759+
// getServerTool looks up a server tool by name.
760+
func (s *Server) getServerTool(name string) (*serverTool, bool) {
757761
s.mu.Lock()
758-
st, ok := s.tools.get(req.Params.Name)
759-
s.mu.Unlock()
762+
defer s.mu.Unlock()
763+
return s.tools.get(name)
764+
}
765+
766+
func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
767+
st, ok := s.getServerTool(req.Params.Name)
760768
if !ok {
761769
return nil, &jsonrpc.Error{
762770
Code: jsonrpc.CodeInvalidParams,

mcp/streamable.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
491491
http.Error(w, "failed connection", http.StatusInternalServerError)
492492
return
493493
}
494+
transport.connection.toolLookup = server.getServerTool
494495
// Capture the user ID from the token info to enable session hijacking
495496
// prevention on subsequent requests.
496497
var userID string
@@ -669,6 +670,8 @@ type streamableServerConn struct {
669670

670671
logger *slog.Logger
671672

673+
toolLookup func(name string) (*serverTool, bool)
674+
672675
incoming chan jsonrpc.Message // messages from the client to the server
673676

674677
mu sync.Mutex // guards all fields below
@@ -1202,9 +1205,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
12021205
}
12031206
}
12041207

1205-
// Validate MCP standard headers (Mcp-Method, Mcp-Name)
1208+
// Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*)
12061209
if !isBatch && len(incoming) == 1 {
1207-
if err := validateMcpHeaders(req.Header, incoming[0]); err != nil {
1210+
if err := validateMcpHeaders(req.Header, incoming[0], c.toolLookup); err != nil {
12081211
resp := &jsonrpc.Response{
12091212
Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()),
12101213
}
@@ -1829,7 +1832,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
18291832
}
18301833
// Keep this after the setMCPHeaders call to ensure that the
18311834
// protocol version header is set.
1832-
setStandardHeaders(req.Header, msg)
1835+
setStandardHeaders(ctx, req.Header, msg)
18331836
resp, err := c.client.Do(req)
18341837
if err != nil {
18351838
// Any error from client.Do means the request didn't reach the server.

0 commit comments

Comments
 (0)