-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_elicitation.py
More file actions
392 lines (316 loc) · 16.9 KB
/
test_elicitation.py
File metadata and controls
392 lines (316 loc) · 16.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
"""Test the elicitation feature using stdio transport."""
from typing import Any
import pytest
from pydantic import BaseModel, Field
from mcp import Client, types
from mcp.client.session import ClientSession, ElicitationFnT
from mcp.server.mcpserver import Context, MCPServer
from mcp.server.session import ServerSession
from mcp.shared._context import RequestContext
from mcp.types import ElicitRequestParams, ElicitResult, TextContent
# Shared schema for basic tests
class AnswerSchema(BaseModel):
    # Single free-text field that the ask_user tool requests via ctx.elicit.
    # NOTE: documented with `#` comments rather than a class docstring because pydantic
    # uses a model docstring as the JSON schema "description" sent to the client.
    answer: str = Field(description="The user's answer to the question")
def create_ask_user_tool(mcp: MCPServer):
    """Register an ``ask_user`` tool on *mcp* that elicits an ``AnswerSchema`` reply.

    The tool turns each elicitation outcome into a fixed string: an accepted answer is
    echoed back, a decline is reported as a refusal, and any other outcome reads as a
    cancellation. Returns the decorated ``ask_user`` object.
    """

    # The tests call this tool as "ask_user" with a "prompt" argument, so the function
    # name and parameter names below must not change.
    @mcp.tool(description="A tool that uses elicitation")
    async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str:
        reply = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema)
        if reply.action == "accept" and reply.data:
            return f"User answered: {reply.data.answer}"
        if reply.action == "decline":
            return "User declined to answer"
        return "User cancelled"  # pragma: no cover

    return ask_user
async def call_tool_and_assert(
    mcp: MCPServer,
    elicitation_callback: ElicitationFnT,
    tool_name: str,
    args: dict[str, Any],
    expected_text: str | None = None,
    text_contains: list[str] | None = None,
):
    """Call *tool_name* on *mcp* through a Client and check its single text result.

    Args:
        mcp: Server hosting the tool.
        elicitation_callback: Client-side handler for elicitation requests the tool makes.
        tool_name: Name of the tool to call.
        args: Arguments passed to the tool.
        expected_text: If given, the result text must equal this exactly.
        text_contains: If given, every substring must appear in the result text.

    At least one of ``expected_text`` / ``text_contains`` must be supplied. When both
    are supplied, both are checked.

    Returns:
        The tool call result, for callers that want to inspect it further.
    """
    # Refuse a call that would verify nothing about the returned text.
    assert expected_text is not None or text_contains is not None, (
        "call_tool_and_assert needs expected_text and/or text_contains"
    )
    async with Client(mcp, elicitation_callback=elicitation_callback) as client:
        result = await client.call_tool(tool_name, args)
        assert len(result.content) == 1
        content = result.content[0]
        assert isinstance(content, TextContent)
        # Independent checks: previously an `elif` silently skipped text_contains
        # whenever expected_text was also provided.
        if expected_text is not None:
            assert content.text == expected_text
        if text_contains is not None:
            for substring in text_contains:
                assert substring in content.text, f"{substring!r} not found in {content.text!r}"
        return result
@pytest.mark.anyio
async def test_stdio_elicitation():
    """Accepting the elicitation makes the tool report the user's answer."""
    mcp = MCPServer(name="StdioElicitationServer")
    create_ask_user_tool(mcp)

    expected_message = "Tool wants to ask: What is your name?"

    async def answer_name(context: RequestContext[ClientSession], params: ElicitRequestParams):
        # Only the exact prompt produced by ask_user is expected in this test.
        if params.message != expected_message:  # pragma: no cover
            raise ValueError(f"Unexpected elicitation message: {params.message}")
        return ElicitResult(action="accept", content={"answer": "Test User"})

    await call_tool_and_assert(
        mcp,
        answer_name,
        "ask_user",
        {"prompt": "What is your name?"},
        expected_text="User answered: Test User",
    )
@pytest.mark.anyio
async def test_stdio_elicitation_decline():
    """A declined elicitation is reported back by the tool as a refusal."""
    mcp = MCPServer(name="StdioElicitationDeclineServer")
    create_ask_user_tool(mcp)

    async def decline_everything(context: RequestContext[ClientSession], params: ElicitRequestParams):
        # Ignore the request contents entirely and always decline.
        return ElicitResult(action="decline")

    await call_tool_and_assert(
        mcp,
        decline_everything,
        "ask_user",
        {"prompt": "What is your name?"},
        expected_text="User declined to answer",
    )
@pytest.mark.anyio
async def test_elicitation_schema_validation():
    """Elicitation schemas may only use primitive fields; int lists and nested models are rejected."""
    mcp = MCPServer(name="ValidationTestServer")

    def create_validation_tool(name: str, schema_class: type[BaseModel]):
        # A factory gives each registered tool its own schema_class binding.
        @mcp.tool(name=name, description=f"Tool testing {name}")
        async def tool(ctx: Context[ServerSession, None]) -> str:
            try:
                await ctx.elicit(message="This should fail validation", schema=schema_class)
                return "Should not reach here"  # pragma: no cover
            except TypeError as exc:
                # Hand the error text back so the client side can inspect it.
                return f"Validation failed as expected: {exc}"

        return tool

    class InvalidListSchema(BaseModel):
        numbers: list[int] = Field(description="List of numbers")

    class NestedModel(BaseModel):
        value: str

    class InvalidNestedSchema(BaseModel):
        nested: NestedModel = Field(description="Nested model")

    # (tool name, schema the tool elicits with, field name expected in the error text)
    invalid_cases: list[tuple[str, type[BaseModel], str]] = [
        ("invalid_list", InvalidListSchema, "numbers"),
        ("nested_model", InvalidNestedSchema, "nested"),
    ]
    for tool_name, schema_class, _ in invalid_cases:
        create_validation_tool(tool_name, schema_class)

    # Placeholder handler: schema validation fails before any elicitation is sent.
    async def elicitation_callback(
        context: RequestContext[ClientSession], params: ElicitRequestParams
    ):  # pragma: no cover
        return ElicitResult(action="accept", content={})

    async with Client(mcp, elicitation_callback=elicitation_callback) as client:
        for tool_name, _, field_name in invalid_cases:
            result = await client.call_tool(tool_name, {})
            assert len(result.content) == 1
            text_content = result.content[0]
            assert isinstance(text_content, TextContent)
            assert "Validation failed as expected" in text_content.text
            assert field_name in text_content.text
@pytest.mark.anyio
async def test_elicitation_with_optional_fields():
    """Test that Optional fields work correctly in elicitation schemas.

    Covers:
    - optional primitive fields, both supplied and omitted;
    - rejection of an optional list of non-strings;
    - list[str] and optional list[str] (multi-select) fields.
    """
    mcp = MCPServer(name="OptionalFieldServer")

    class OptionalSchema(BaseModel):
        required_name: str = Field(description="Your name (required)")
        optional_age: int | None = Field(default=None, description="Your age (optional)")
        optional_email: str | None = Field(default=None, description="Your email (optional)")
        subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?")

    @mcp.tool(description="Tool with optional fields")
    async def optional_tool(ctx: Context[ServerSession, None]) -> str:
        result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema)
        if result.action == "accept" and result.data:
            info = [f"Name: {result.data.required_name}"]
            if result.data.optional_age is not None:
                info.append(f"Age: {result.data.optional_age}")
            if result.data.optional_email is not None:
                info.append(f"Email: {result.data.optional_email}")
            info.append(f"Subscribe: {result.data.subscribe}")
            return ", ".join(info)
        else:  # pragma: no cover
            return f"User {result.action}"

    def make_accept_callback(content: dict[str, Any]) -> ElicitationFnT:
        """Build a callback that accepts every elicitation with the fixed *content* payload."""

        # Binding *content* through a factory avoids closing over a loop variable
        # (late binding, ruff B023) when the callback is created inside the loop below.
        async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
            return ElicitResult(action="accept", content=content)

        return callback

    # Test cases with different field combinations
    test_cases: list[tuple[dict[str, Any], str]] = [
        (
            # All fields provided
            {"required_name": "John Doe", "optional_age": 30, "optional_email": "john@example.com", "subscribe": True},
            "Name: John Doe, Age: 30, Email: john@example.com, Subscribe: True",
        ),
        (
            # Only required fields
            {"required_name": "Jane Smith"},
            "Name: Jane Smith, Subscribe: False",
        ),
    ]

    for content, expected in test_cases:
        await call_tool_and_assert(mcp, make_accept_callback(content), "optional_tool", {}, expected)

    # Test invalid optional field
    class InvalidOptionalSchema(BaseModel):
        name: str = Field(description="Name")
        optional_list: list[int] | None = Field(default=None, description="Invalid optional list")

    @mcp.tool(description="Tool with invalid optional field")
    async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str:
        try:
            await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema)
            return "Should not reach here"  # pragma: no cover
        except TypeError as e:
            return f"Validation failed: {str(e)}"

    async def elicitation_callback(
        context: RequestContext[ClientSession], params: ElicitRequestParams
    ):  # pragma: no cover
        return ElicitResult(action="accept", content={})

    await call_tool_and_assert(
        mcp,
        elicitation_callback,
        "invalid_optional_tool",
        {},
        text_contains=["Validation failed:", "optional_list"],
    )

    # Test valid list[str] for multi-select enum
    class ValidMultiSelectSchema(BaseModel):
        name: str = Field(description="Name")
        tags: list[str] = Field(description="Tags")

    @mcp.tool(description="Tool with valid list[str] field")
    async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str:
        result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema)
        if result.action == "accept" and result.data:
            return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}"
        return f"User {result.action}"  # pragma: no cover

    async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
        if "Please provide tags" in params.message:
            return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
        return ElicitResult(action="decline")  # pragma: no cover

    await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2")

    # Test Optional[list[str]] for optional multi-select enum
    class OptionalMultiSelectSchema(BaseModel):
        name: str = Field(description="Name")
        tags: list[str] | None = Field(default=None, description="Optional tags")

    @mcp.tool(description="Tool with optional list[str] field")
    async def optional_multiselect_tool(ctx: Context[ServerSession, None]) -> str:
        result = await ctx.elicit(message="Please provide optional tags", schema=OptionalMultiSelectSchema)
        if result.action == "accept" and result.data:
            tags_str = ", ".join(result.data.tags) if result.data.tags else "none"
            return f"Name: {result.data.name}, Tags: {tags_str}"
        return f"User {result.action}"  # pragma: no cover

    async def optional_multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
        if "Please provide optional tags" in params.message:
            return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
        return ElicitResult(action="decline")  # pragma: no cover

    await call_tool_and_assert(
        mcp, optional_multiselect_callback, "optional_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2"
    )
@pytest.mark.anyio
async def test_elicitation_with_default_values():
    """Field defaults appear in the elicitation JSON schema and fill in values the client omits."""
    mcp = MCPServer(name="DefaultValuesServer")

    class DefaultsSchema(BaseModel):
        name: str = Field(default="Guest", description="User name")
        age: int = Field(default=18, description="User age")
        subscribe: bool = Field(default=True, description="Subscribe to newsletter")
        email: str = Field(description="Email address (required)")

    @mcp.tool(description="Tool with default values")
    async def defaults_tool(ctx: Context[ServerSession, None]) -> str:
        outcome = await ctx.elicit(message="Please provide your information", schema=DefaultsSchema)
        if outcome.action == "accept" and outcome.data:
            data = outcome.data
            return f"Name: {data.name}, Age: {data.age}, Subscribe: {data.subscribe}, Email: {data.email}"
        return f"User {outcome.action}"  # pragma: no cover

    async def verify_defaults_then_accept(context: RequestContext[ClientSession], params: ElicitRequestParams):
        # The schema advertised to the client must carry each optional field's default.
        assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation"
        props = params.requested_schema["properties"]
        assert props["name"]["default"] == "Guest"
        assert props["age"]["default"] == 18
        assert props["subscribe"]["default"] is True
        assert "default" not in props["email"]  # the required field has no default
        # Supply only the required field so the server has to apply the defaults.
        return ElicitResult(action="accept", content={"email": "test@example.com"})

    await call_tool_and_assert(
        mcp,
        verify_defaults_then_accept,
        "defaults_tool",
        {},
        "Name: Guest, Age: 18, Subscribe: True, Email: test@example.com",
    )

    async def override_all_fields(context: RequestContext[ClientSession], params: ElicitRequestParams):
        # Explicit values must win over the schema defaults.
        overrides = {"email": "john@example.com", "name": "John", "age": 25, "subscribe": False}
        return ElicitResult(action="accept", content=overrides)

    await call_tool_and_assert(
        mcp,
        override_all_fields,
        "defaults_tool",
        {},
        "Name: John, Age: 25, Subscribe: False, Email: john@example.com",
    )
@pytest.mark.anyio
async def test_elicitation_with_enum_titles():
    """Titled enum fields (oneOf, items.anyOf, legacy enumNames) round-trip through elicitation."""
    mcp = MCPServer(name="ColorPreferencesApp")

    # Single-select: each allowed value carries a display title via oneOf.
    class FavoriteColorSchema(BaseModel):
        user_name: str = Field(description="Your name")
        favorite_color: str = Field(
            description="Select your favorite color",
            json_schema_extra={
                "oneOf": [
                    {"const": "red", "title": "Red"},
                    {"const": "green", "title": "Green"},
                    {"const": "blue", "title": "Blue"},
                    {"const": "yellow", "title": "Yellow"},
                ]
            },
        )

    @mcp.tool(description="Single color selection")
    async def select_favorite_color(ctx: Context[ServerSession, None]) -> str:
        answer = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema)
        if answer.action == "accept" and answer.data:
            return f"User: {answer.data.user_name}, Favorite: {answer.data.favorite_color}"
        return f"User {answer.action}"  # pragma: no cover

    # Multi-select: titles are attached to the array items via anyOf.
    class FavoriteColorsSchema(BaseModel):
        user_name: str = Field(description="Your name")
        favorite_colors: list[str] = Field(
            description="Select your favorite colors",
            json_schema_extra={
                "items": {
                    "anyOf": [
                        {"const": "red", "title": "Red"},
                        {"const": "green", "title": "Green"},
                        {"const": "blue", "title": "Blue"},
                        {"const": "yellow", "title": "Yellow"},
                    ]
                }
            },
        )

    @mcp.tool(description="Multiple color selection")
    async def select_favorite_colors(ctx: Context[ServerSession, None]) -> str:
        answer = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema)
        if answer.action == "accept" and answer.data:
            return f"User: {answer.data.user_name}, Colors: {', '.join(answer.data.favorite_colors)}"
        return f"User {answer.action}"  # pragma: no cover

    # Legacy format: parallel enum / enumNames arrays.
    class LegacyColorSchema(BaseModel):
        user_name: str = Field(description="Your name")
        color: str = Field(
            description="Select a color",
            json_schema_extra={"enum": ["red", "green", "blue"], "enumNames": ["Red", "Green", "Blue"]},
        )

    @mcp.tool(description="Legacy enum format")
    async def select_color_legacy(ctx: Context[ServerSession, None]) -> str:
        answer = await ctx.elicit(message="Select a color (legacy format)", schema=LegacyColorSchema)
        if answer.action == "accept" and answer.data:
            return f"User: {answer.data.user_name}, Color: {answer.data.color}"
        return f"User {answer.action}"  # pragma: no cover

    async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
        # Route on the prompt text. The legacy prompt also contains "color", so it is
        # checked first, and the plural "colors" is checked before the bare "color".
        message = params.message
        if "color" not in message:  # pragma: no cover
            return ElicitResult(action="decline")
        if "legacy" in message:
            return ElicitResult(action="accept", content={"user_name": "Charlie", "color": "green"})
        if "colors" in message:
            return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]})
        return ElicitResult(action="accept", content={"user_name": "Alice", "favorite_color": "blue"})

    scenarios = [
        ("select_favorite_color", "User: Alice, Favorite: blue"),  # single-select with titles
        ("select_favorite_colors", "User: Bob, Colors: red, green"),  # multi-select with titles
        ("select_color_legacy", "User: Charlie, Color: green"),  # legacy enumNames format
    ]
    for tool_name, expected in scenarios:
        await call_tool_and_assert(mcp, enum_callback, tool_name, {}, expected)