|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import contextvars |
| 6 | +from collections.abc import Iterator |
| 7 | +from contextlib import contextmanager |
6 | 8 | from unittest.mock import patch |
7 | 9 |
|
8 | 10 | import anyio |
@@ -323,24 +325,29 @@ async def test_client_uses_transport_directly(app: MCPServer): |
323 | 325 | ) |
324 | 326 |
|
325 | 327 |
|
| 328 | +_TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial") |
| 329 | + |
| 330 | + |
| 331 | +@contextmanager |
| 332 | +def _set_test_contextvar(value: str) -> Iterator[None]: |
| 333 | + token = _TEST_CONTEXTVAR.set(value) |
| 334 | + try: |
| 335 | + yield |
| 336 | + finally: |
| 337 | + _TEST_CONTEXTVAR.reset(token) |
| 338 | + |
| 339 | + |
326 | 340 | async def test_context_propagation(): |
327 | 341 | """Sender's contextvars.Context is propagated to the server handler.""" |
328 | | - trace_id = contextvars.ContextVar[str]("trace_id", default="unset") |
329 | | - captured: list[str] = [] |
330 | | - |
331 | 342 | server = MCPServer("test") |
332 | 343 |
|
333 | 344 | @server.tool() |
334 | | - def check_context() -> str: |
335 | | - """Return the trace_id contextvar value.""" |
336 | | - value = trace_id.get() |
337 | | - captured.append(value) |
338 | | - return value |
339 | | - |
340 | | - trace_id.set("test-trace-123") |
| 345 | + async def check_context() -> str: |
| 346 | + """Return the contextvar value visible to the handler.""" |
| 347 | + return _TEST_CONTEXTVAR.get() |
341 | 348 |
|
342 | 349 | async with Client(server) as client: |
343 | | - result = await client.call_tool("check_context", {}) |
| 350 | + with _set_test_contextvar("client_value"): |
| 351 | + result = await client.call_tool("check_context", {}) |
344 | 352 |
|
345 | | - assert captured == ["test-trace-123"] |
346 | | - assert result.content[0].text == "test-trace-123" # type: ignore[union-attr] |
| 353 | + assert result.content[0].text == "client_value" # type: ignore[union-attr] |
0 commit comments