Skip to content

Commit 96e62af

Browse files
committed
fix: enforce server-to-client request scope by restricting client request ID context access
1 parent 93a41b2 commit 96e62af

8 files changed

Lines changed: 283 additions & 94 deletions

mcp/client_example_test.go

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,36 @@ func Example_sampling() {
7878
},
7979
})
8080

81-
// Connect the server and client...
81+
// Create a server with a tool that uses sampling.
82+
// Server-to-client requests like CreateMessage must be made within a
83+
// client request handler (e.g. a tool call).
8284
ct, st := mcp.NewInMemoryTransports()
8385
s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil)
86+
s.AddTool(&mcp.Tool{Name: "sample", InputSchema: map[string]any{"type": "object"}}, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
87+
msg, err := req.Session.CreateMessage(ctx, &mcp.CreateMessageParams{})
88+
if err != nil {
89+
return nil, err
90+
}
91+
return &mcp.CallToolResult{
92+
Content: []mcp.Content{&mcp.TextContent{Text: msg.Content.(*mcp.TextContent).Text}},
93+
}, nil
94+
})
8495
session, err := s.Connect(ctx, st, nil)
8596
if err != nil {
8697
log.Fatal(err)
8798
}
8899
defer session.Close()
89100

90-
if _, err := c.Connect(ctx, ct, nil); err != nil {
101+
cs, err := c.Connect(ctx, ct, nil)
102+
if err != nil {
91103
log.Fatal(err)
92104
}
93105

94-
msg, err := session.CreateMessage(ctx, &mcp.CreateMessageParams{})
106+
res, err := cs.CallTool(ctx, &mcp.CallToolParams{Name: "sample"})
95107
if err != nil {
96108
log.Fatal(err)
97109
}
98-
fmt.Println(msg.Content.(*mcp.TextContent).Text)
110+
fmt.Println(res.Content[0].(*mcp.TextContent).Text)
99111
// Output: would have created a message
100112
}
101113

@@ -107,7 +119,27 @@ func Example_elicitation() {
107119
ctx := context.Background()
108120
ct, st := mcp.NewInMemoryTransports()
109121

122+
// Create a server with a tool that uses elicitation.
123+
// Server-to-client requests like Elicit must be made within a client
124+
// request handler (e.g. a tool call).
110125
s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil)
126+
s.AddTool(&mcp.Tool{Name: "ask", InputSchema: map[string]any{"type": "object"}}, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
127+
res, err := req.Session.Elicit(ctx, &mcp.ElicitParams{
128+
Message: "Please provide information",
129+
RequestedSchema: &jsonschema.Schema{
130+
Type: "object",
131+
Properties: map[string]*jsonschema.Schema{
132+
"test": {Type: "string"},
133+
},
134+
},
135+
})
136+
if err != nil {
137+
return nil, err
138+
}
139+
return &mcp.CallToolResult{
140+
Content: []mcp.Content{&mcp.TextContent{Text: res.Content["test"].(string)}},
141+
}, nil
142+
})
111143
ss, err := s.Connect(ctx, st, nil)
112144
if err != nil {
113145
log.Fatal(err)
@@ -119,22 +151,16 @@ func Example_elicitation() {
119151
return &mcp.ElicitResult{Action: "accept", Content: map[string]any{"test": "value"}}, nil
120152
},
121153
})
122-
if _, err := c.Connect(ctx, ct, nil); err != nil {
154+
cs, err := c.Connect(ctx, ct, nil)
155+
if err != nil {
123156
log.Fatal(err)
124157
}
125-
res, err := ss.Elicit(ctx, &mcp.ElicitParams{
126-
Message: "This should fail",
127-
RequestedSchema: &jsonschema.Schema{
128-
Type: "object",
129-
Properties: map[string]*jsonschema.Schema{
130-
"test": {Type: "string"},
131-
},
132-
},
133-
})
158+
159+
toolRes, err := cs.CallTool(ctx, &mcp.CallToolParams{Name: "ask"})
134160
if err != nil {
135161
log.Fatal(err)
136162
}
137-
fmt.Println(res.Content["test"])
163+
fmt.Println(toolRes.Content[0].(*mcp.TextContent).Text)
138164
// Output: value
139165
}
140166

mcp/elicitation_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func TestElicitationURLMode(t *testing.T) {
120120
}
121121
defer cs.Close()
122122

123-
result, err := ss.Elicit(ctx, tc.params)
123+
result, err := ss.Elicit(clientRequestCtx(ctx), tc.params)
124124

125125
if tc.wantErrMsg != "" {
126126
if err == nil || !strings.Contains(err.Error(), tc.wantErrMsg) {
@@ -170,7 +170,7 @@ func TestElicitationCompleteNotification(t *testing.T) {
170170

171171
// 1. Server initiates a URL elicitation
172172
elicitID := "testElicitationID-123"
173-
resp, err := ss.Elicit(ctx, &ElicitParams{
173+
resp, err := ss.Elicit(clientRequestCtx(ctx), &ElicitParams{
174174
Mode: "url",
175175
Message: "Please complete this form: ",
176176
URL: "https://example.com/form?id=" + elicitID,
@@ -251,7 +251,7 @@ func TestElicitationNoValidationWithoutAccept(t *testing.T) {
251251
}
252252
defer cs.Close()
253253

254-
res, err := ss.Elicit(ctx, &ElicitParams{
254+
res, err := ss.Elicit(clientRequestCtx(ctx), &ElicitParams{
255255
Message: "Test bug",
256256
RequestedSchema: schema,
257257
})

0 commit comments

Comments
 (0)