Skip to content

Commit 9dd6907

Browse files
committed
fix: add mutex protection to toolCache and provide thread-safe accessors in ClientSession
1 parent 24cc607 commit 9dd6907

3 files changed

Lines changed: 85 additions & 12 deletions

File tree

mcp/client.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ type ClientSession struct {
325325
pendingElicitationsMu sync.Mutex
326326
pendingElicitations map[string]chan struct{}
327327

328+
// toolCacheMu guards toolCache.
329+
toolCacheMu sync.RWMutex
328330
// toolCache stores tool definitions keyed by name.
329331
// It is used to look up x-mcp-header annotations when
330332
// constructing Mcp-Param-* headers for tools/call requests.
@@ -375,6 +377,8 @@ func (cs *ClientSession) Wait() error {
375377
}
376378

377379
func (cs *ClientSession) cacheTools(tools []*Tool) {
380+
cs.toolCacheMu.Lock()
381+
defer cs.toolCacheMu.Unlock()
378382
if cs.toolCache == nil {
379383
cs.toolCache = make(map[string]*Tool, len(tools))
380384
}
@@ -383,6 +387,12 @@ func (cs *ClientSession) cacheTools(tools []*Tool) {
383387
}
384388
}
385389

390+
func (cs *ClientSession) getCachedTool(name string) *Tool {
391+
cs.toolCacheMu.RLock()
392+
defer cs.toolCacheMu.RUnlock()
393+
return cs.toolCache[name]
394+
}
395+
386396
// registerElicitationWaiter registers a waiter for an elicitation complete
387397
// notification with the given elicitation ID. It returns two functions: an await
388398
// function that waits for the notification or context cancellation, and a cleanup
@@ -1021,7 +1031,7 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (
10211031
// Avoid sending nil over the wire.
10221032
params.Arguments = map[string]any{}
10231033
}
1024-
if tool := cs.toolCache[params.Name]; tool != nil {
1034+
if tool := cs.getCachedTool(params.Name); tool != nil {
10251035
ctx = context.WithValue(ctx, toolContextKey, tool)
10261036
}
10271037
return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))

mcp/client_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,80 @@ func TestClientCapabilities(t *testing.T) {
440440
}
441441
}
442442

443+
func TestToolCache(t *testing.T) {
444+
tool1 := &Tool{Name: "tool1", Description: "first"}
445+
tool2 := &Tool{Name: "tool2", Description: "second"}
446+
tool1Updated := &Tool{Name: "tool1", Description: "updated"}
447+
448+
testCases := []struct {
449+
name string
450+
cacheBatches [][]*Tool
451+
lookup string
452+
want *Tool
453+
}{
454+
{
455+
name: "empty cache",
456+
lookup: "tool1",
457+
want: nil,
458+
},
459+
{
460+
name: "single tool found",
461+
cacheBatches: [][]*Tool{{tool1}},
462+
lookup: "tool1",
463+
want: tool1,
464+
},
465+
{
466+
name: "unknown tool",
467+
cacheBatches: [][]*Tool{{tool1}},
468+
lookup: "nonexistent",
469+
want: nil,
470+
},
471+
{
472+
name: "multiple tools single batch",
473+
cacheBatches: [][]*Tool{{tool1, tool2}},
474+
lookup: "tool2",
475+
want: tool2,
476+
},
477+
{
478+
name: "additive first tool retained",
479+
cacheBatches: [][]*Tool{{tool1}, {tool2}},
480+
lookup: "tool1",
481+
want: tool1,
482+
},
483+
{
484+
name: "additive second tool added",
485+
cacheBatches: [][]*Tool{{tool1}, {tool2}},
486+
lookup: "tool2",
487+
want: tool2,
488+
},
489+
{
490+
name: "overwrite existing entry",
491+
cacheBatches: [][]*Tool{{tool1}, {tool1Updated}},
492+
lookup: "tool1",
493+
want: tool1Updated,
494+
},
495+
{
496+
name: "empty batch no-op",
497+
cacheBatches: [][]*Tool{{}},
498+
lookup: "tool1",
499+
want: nil,
500+
},
501+
}
502+
503+
for _, tc := range testCases {
504+
t.Run(tc.name, func(t *testing.T) {
505+
cs := &ClientSession{}
506+
for _, batch := range tc.cacheBatches {
507+
cs.cacheTools(batch)
508+
}
509+
got := cs.getCachedTool(tc.lookup)
510+
if diff := cmp.Diff(tc.want, got); diff != "" {
511+
t.Errorf("getCachedTool(%q) mismatch (-want +got):\n%s", tc.lookup, diff)
512+
}
513+
})
514+
}
515+
}
516+
443517
func TestClientCapabilitiesOverWire(t *testing.T) {
444518
testCases := []struct {
445519
name string

mcp/streamable_headers.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ const (
2727
mcpHeaderExtension = "x-mcp-header"
2828
)
2929

30-
// ---------------------------------------------------------------------------
31-
// Shared helpers (used by both client and server)
32-
// ---------------------------------------------------------------------------
3330

3431
func extractName(method string, params json.RawMessage) (string, bool) {
3532
switch method {
@@ -122,10 +119,6 @@ func unmarshalPrimitive(raw json.RawMessage) any {
122119
}
123120
}
124121

125-
// ---------------------------------------------------------------------------
126-
// Client-side helpers
127-
// ---------------------------------------------------------------------------
128-
129122
// setStandardHeaders populates standard MCP headers.
130123
// It requires the protocol version header to be set.
131124
func setStandardHeaders(header http.Header, msg jsonrpc.Message) {
@@ -279,10 +272,6 @@ func validateHeaderName(name string) error {
279272
return nil
280273
}
281274

282-
// ---------------------------------------------------------------------------
283-
// Server-side helpers
284-
// ---------------------------------------------------------------------------
285-
286275
func validateMcpHeaders(header http.Header, msg jsonrpc.Message, tool *Tool) error {
287276
protocolVersion := header.Get(protocolVersionHeader)
288277
if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders {

0 commit comments

Comments
 (0)