-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathclient.py
More file actions
367 lines (296 loc) · 14.3 KB
/
client.py
File metadata and controls
367 lines (296 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
"""MCP unified conformance test client.
This client is designed to work with the @modelcontextprotocol/conformance npm package.
It handles all conformance test scenarios via environment variables and CLI arguments.
Contract:
- MCP_CONFORMANCE_SCENARIO env var -> scenario name
- MCP_CONFORMANCE_CONTEXT env var -> optional JSON (required for client-credentials scenarios; may also supply a pre-registered client_id/client_secret for the authorization code flow)
- Server URL as last CLI argument (sys.argv[1])
- Must exit 0 within 30 seconds
Scenarios:
initialize - Connect, initialize, list tools, close
tools_call - Connect, call add_numbers(a=5, b=3), close
sse-retry - Connect, call test_reconnection, close
elicitation-sep1034-client-defaults - Elicitation with default accept callback
auth/client-credentials-jwt - Client credentials with private_key_jwt
auth/client-credentials-basic - Client credentials with client_secret_basic
auth/* - Authorization code flow (default for auth scenarios)
"""
import asyncio
import json
import logging
import os
import sys
from collections.abc import Callable, Coroutine
from typing import Any, cast
from urllib.parse import parse_qs, urlparse
import httpx
from pydantic import AnyUrl
from mcp import ClientSession, types
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.auth.extensions.client_credentials import (
ClientCredentialsOAuthProvider,
PrivateKeyJWTOAuthProvider,
SignedJWTParameters,
)
from mcp.client.context import ClientRequestContext
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
# Set up logging to stderr (stdout is for conformance test output).
# The root logger is configured at DEBUG so the debug traces emitted by the
# scenario handlers in this module are actually written out.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    stream=sys.stderr,
)
logger = logging.getLogger(__name__)

# Type for async scenario handler functions: called with the server URL (str)
# and returning a coroutine whose result is None.
ScenarioHandler = Callable[[str], Coroutine[Any, None, None]]

# Registry of scenario handlers, keyed by scenario name. Entries are added by
# the @register decorator and looked up by main().
HANDLERS: dict[str, ScenarioHandler] = {}
def register(name: str) -> Callable[[ScenarioHandler], ScenarioHandler]:
    """Build a decorator that files a scenario handler in ``HANDLERS`` under ``name``.

    The decorated function is returned unchanged, so it remains callable under
    its own name as well as through the registry.
    """

    def _remember(handler: ScenarioHandler) -> ScenarioHandler:
        # A later registration under the same name replaces the earlier one.
        HANDLERS[name] = handler
        return handler

    return _remember
def get_conformance_context() -> dict[str, Any]:
    """Load the conformance test context from ``MCP_CONFORMANCE_CONTEXT``.

    Returns:
        The decoded JSON object as a dict.

    Raises:
        RuntimeError: If the variable is unset or empty, does not contain
            valid JSON, or decodes to something other than a JSON object.
    """
    context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT")
    if not context_json:
        raise RuntimeError(
            "MCP_CONFORMANCE_CONTEXT environment variable not set. "
            "Expected JSON with client_id, client_secret, and/or private_key_pem."
        )
    try:
        context = json.loads(context_json)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Failed to parse MCP_CONFORMANCE_CONTEXT as JSON: {e}") from e
    # Callers use .get() on the result, so a JSON list/number/string/null would
    # otherwise surface later as an unrelated AttributeError.
    if not isinstance(context, dict):
        raise RuntimeError(f"MCP_CONFORMANCE_CONTEXT must be a JSON object, got {type(context).__name__}")
    return context
class InMemoryTokenStorage(TokenStorage):
    """``TokenStorage`` implementation that keeps OAuth state in instance attributes.

    Nothing is persisted; stored values live only as long as the instance,
    which is all the conformance client needs.
    """

    def __init__(self) -> None:
        # Both slots start empty and are only ever filled through the setters.
        self._client_registration: OAuthClientInformationFull | None = None
        self._current_tokens: OAuthToken | None = None

    async def get_tokens(self) -> OAuthToken | None:
        """Return the most recently stored tokens, or ``None`` if none were stored."""
        return self._current_tokens

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Replace the stored tokens with ``tokens``."""
        self._current_tokens = tokens

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Return the stored client registration, or ``None`` if none was stored."""
        return self._client_registration

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Replace the stored client registration with ``client_info``."""
        self._client_registration = client_info
class ConformanceOAuthCallbackHandler:
"""OAuth callback handler that automatically fetches the authorization URL
and extracts the auth code, without requiring user interaction.
"""
def __init__(self) -> None:
self._auth_code: str | None = None
self._state: str | None = None
async def handle_redirect(self, authorization_url: str) -> None:
"""Fetch the authorization URL and extract the auth code from the redirect."""
logger.debug(f"Fetching authorization URL: {authorization_url}")
async with httpx.AsyncClient() as client:
response = await client.get(
authorization_url,
follow_redirects=False,
)
if response.status_code in (301, 302, 303, 307, 308):
location = cast(str, response.headers.get("location"))
if location:
redirect_url = urlparse(location)
query_params: dict[str, list[str]] = parse_qs(redirect_url.query)
if "code" in query_params:
self._auth_code = query_params["code"][0]
state_values = query_params.get("state")
self._state = state_values[0] if state_values else None
logger.debug(f"Got auth code from redirect: {self._auth_code[:10]}...")
return
else:
raise RuntimeError(f"No auth code in redirect URL: {location}")
else:
raise RuntimeError(f"No redirect location received from {authorization_url}")
else:
raise RuntimeError(f"Expected redirect response, got {response.status_code} from {authorization_url}")
async def handle_callback(self) -> tuple[str, str | None]:
"""Return the captured auth code and state."""
if self._auth_code is None:
raise RuntimeError("No authorization code available - was handle_redirect called?")
auth_code = self._auth_code
state = self._state
self._auth_code = None
self._state = None
return auth_code, state
# --- Scenario Handlers ---
@register("initialize")
async def run_initialize(server_url: str) -> None:
    """Open a session against ``server_url``, initialize it, list the tools, then close."""
    async with (
        streamable_http_client(url=server_url) as (reader, writer),
        ClientSession(reader, writer) as session,
    ):
        await session.initialize()
        logger.debug("Initialized successfully")

        await session.list_tools()
        logger.debug("Listed tools successfully")
@register("tools_call")
async def run_tools_call(server_url: str) -> None:
    """Initialize a session, list the tools, and call ``add_numbers`` with a=5, b=3."""
    arguments = {"a": 5, "b": 3}
    async with (
        streamable_http_client(url=server_url) as (reader, writer),
        ClientSession(reader, writer) as session,
    ):
        await session.initialize()
        await session.list_tools()
        outcome = await session.call_tool("add_numbers", arguments)
        logger.debug(f"add_numbers result: {outcome}")
@register("sse-retry")
async def run_sse_retry(server_url: str) -> None:
    """Initialize a session, list the tools, and call ``test_reconnection`` with no arguments."""
    async with (
        streamable_http_client(url=server_url) as (reader, writer),
        ClientSession(reader, writer) as session,
    ):
        await session.initialize()
        await session.list_tools()
        outcome = await session.call_tool("test_reconnection", {})
        logger.debug(f"test_reconnection result: {outcome}")
async def default_elicitation_callback(
    context: ClientRequestContext,
    params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
    """Accept every elicitation, answering with the schema's declared defaults (SEP-1034).

    For form-mode requests, each property of ``requested_schema`` that declares
    a ``default`` contributes that value to the reply. Any other request is
    accepted with empty content. ``context`` is not used.
    """
    defaults: dict[str, str | int | float | bool | list[str] | None] = {}

    if isinstance(params, types.ElicitRequestFormParams):
        requested = params.requested_schema
        logger.debug(f"Elicitation schema: {requested}")
        # Properties without a "default" key are left out of the reply.
        defaults.update(
            (field, spec["default"])
            for field, spec in requested.get("properties", {}).items()
            if "default" in spec
        )
        logger.debug(f"Applied defaults: {defaults}")

    return types.ElicitResult(action="accept", content=defaults)
@register("elicitation-sep1034-client-defaults")
async def run_elicitation_defaults(server_url: str) -> None:
    """Call ``test_client_elicitation_defaults`` on a session that answers elicitations with schema defaults."""
    async with (
        streamable_http_client(url=server_url) as (reader, writer),
        ClientSession(reader, writer, elicitation_callback=default_elicitation_callback) as session,
    ):
        await session.initialize()
        await session.list_tools()
        outcome = await session.call_tool("test_client_elicitation_defaults", {})
        logger.debug(f"test_client_elicitation_defaults result: {outcome}")
@register("auth/client-credentials-jwt")
async def run_client_credentials_jwt(server_url: str) -> None:
    """Run the client credentials flow, authenticating with ``private_key_jwt``.

    Reads ``client_id``, ``private_key_pem`` and an optional
    ``signing_algorithm`` (default ``"ES256"``) from the conformance context.

    Raises:
        RuntimeError: If the context lacks ``client_id`` or ``private_key_pem``.
    """
    ctx = get_conformance_context()

    client_id = ctx.get("client_id")
    if not client_id:
        raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_id'")

    private_key_pem = ctx.get("private_key_pem")
    if not private_key_pem:
        raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'private_key_pem'")

    # The client ID is used as both issuer and subject of the signed assertion.
    assertion_provider = SignedJWTParameters(
        issuer=client_id,
        subject=client_id,
        signing_algorithm=ctx.get("signing_algorithm", "ES256"),
        signing_key=private_key_pem,
    ).create_assertion_provider()

    provider = PrivateKeyJWTOAuthProvider(
        server_url=server_url,
        storage=InMemoryTokenStorage(),
        client_id=client_id,
        assertion_provider=assertion_provider,
    )
    await _run_auth_session(server_url, provider)
@register("auth/client-credentials-basic")
async def run_client_credentials_basic(server_url: str) -> None:
    """Run the client credentials flow, authenticating with ``client_secret_basic``.

    Reads ``client_id`` and ``client_secret`` from the conformance context.

    Raises:
        RuntimeError: If the context lacks ``client_id`` or ``client_secret``.
    """
    ctx = get_conformance_context()

    client_id = ctx.get("client_id")
    if not client_id:
        raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_id'")

    client_secret = ctx.get("client_secret")
    if not client_secret:
        raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_secret'")

    provider = ClientCredentialsOAuthProvider(
        server_url=server_url,
        storage=InMemoryTokenStorage(),
        client_id=client_id,
        client_secret=client_secret,
        token_endpoint_auth_method="client_secret_basic",
    )
    await _run_auth_session(server_url, provider)
async def run_auth_code_client(server_url: str) -> None:
    """Run the authorization code flow (default for ``auth/*`` scenarios).

    If ``MCP_CONFORMANCE_CONTEXT`` holds a JSON object with a ``client_id``,
    that client (and its optional ``client_secret``) is pre-loaded into the
    token storage before the flow starts. The context is optional here: when
    it is absent, not valid JSON, or not a JSON object, the problem is logged
    and the flow continues without pre-registered credentials.

    Args:
        server_url: URL of the MCP server under test.
    """
    redirect_uri = AnyUrl("http://localhost:3000/callback")
    callback_handler = ConformanceOAuthCallbackHandler()
    storage = InMemoryTokenStorage()

    # Check for pre-registered client credentials from context.
    preregistered: dict[str, Any] = {}
    context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT")
    if context_json:
        # Only json.loads is guarded, so errors from the storage/model code
        # below are never mistaken for a malformed context.
        try:
            parsed = json.loads(context_json)
        except json.JSONDecodeError as e:
            logger.warning(f"Ignoring MCP_CONFORMANCE_CONTEXT: not valid JSON ({e})")
        else:
            if isinstance(parsed, dict):
                preregistered = parsed
            else:
                logger.warning(
                    f"Ignoring MCP_CONFORMANCE_CONTEXT: expected a JSON object, got {type(parsed).__name__}"
                )

    client_id = preregistered.get("client_id")
    if client_id:
        client_secret = preregistered.get("client_secret")
        await storage.set_client_info(
            OAuthClientInformationFull(
                client_id=client_id,
                client_secret=client_secret,
                redirect_uris=[redirect_uri],
                # Without a secret the client is set up with auth method "none".
                token_endpoint_auth_method="client_secret_basic" if client_secret else "none",
            )
        )
        logger.debug(f"Pre-loaded client credentials: client_id={client_id}")

    oauth_auth = OAuthClientProvider(
        server_url=server_url,
        client_metadata=OAuthClientMetadata(
            client_name="conformance-client",
            redirect_uris=[redirect_uri],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
        ),
        storage=storage,
        redirect_handler=callback_handler.handle_redirect,
        callback_handler=callback_handler.handle_callback,
        client_metadata_url="https://conformance-test.local/client-metadata.json",
    )
    await _run_auth_session(server_url, oauth_auth)
async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None:
    """Common session logic for all OAuth flows.

    Opens an authenticated session, initializes it, lists the tools and calls
    the first one with empty arguments. A failing tool call is logged and
    otherwise ignored.

    Args:
        server_url: URL of the MCP server under test.
        oauth_auth: Auth provider attached to the HTTP client used for the session.
    """
    client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0)
    try:
        async with streamable_http_client(url=server_url, http_client=client) as (read_stream, write_stream):
            async with ClientSession(
                read_stream, write_stream, elicitation_callback=default_elicitation_callback
            ) as session:
                await session.initialize()
                logger.debug("Initialized successfully")

                tools_result = await session.list_tools()
                logger.debug(f"Listed tools: {[t.name for t in tools_result.tools]}")

                # Call the first available tool (different tests have different tools)
                if tools_result.tools:
                    tool_name = tools_result.tools[0].name
                    try:
                        result = await session.call_tool(tool_name, {})
                        logger.debug(f"Called {tool_name}, result: {result}")
                    except Exception as e:
                        # Deliberately best-effort: a failing tool call is only logged.
                        logger.debug(f"Tool call result/error: {e}")

        logger.debug("Connection closed successfully")
    finally:
        # The client is created in this function, so release its connections
        # here on every exit path.
        await client.aclose()
def main() -> None:
    """Entry point: pick a scenario from the environment and run it against the server URL.

    The server URL is read from ``sys.argv[1]`` and the scenario name from
    ``MCP_CONFORMANCE_SCENARIO``. With no scenario set, the authorization code
    flow is run. Exits with status 1 when the URL argument is missing or the
    scenario is unknown.
    """
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <server-url>", file=sys.stderr)
        sys.exit(1)

    server_url = sys.argv[1]
    scenario = os.environ.get("MCP_CONFORMANCE_SCENARIO")

    if not scenario:
        logger.debug(f"Running default auth flow against {server_url}")
        asyncio.run(run_auth_code_client(server_url))
        return

    logger.debug(f"Running explicit scenario '{scenario}' against {server_url}")
    runner = HANDLERS.get(scenario)
    if not runner and scenario.startswith("auth/"):
        # Unregistered auth/* scenarios fall back to the authorization code flow.
        runner = run_auth_code_client

    if runner:
        asyncio.run(runner(server_url))
    else:
        print(f"Unknown scenario: {scenario}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()