Skip to content

Commit 70854a4

Browse files
committed
Test cases
See if tests still fail
1 parent 239d682 commit 70854a4

2 files changed

Lines changed: 281 additions & 92 deletions

File tree

tests/test_context_propagation.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import contextvars
2+
import multiprocessing
3+
import socket
4+
from collections.abc import Iterator
5+
from contextlib import contextmanager
6+
from typing import Literal, assert_never
7+
8+
import httpx
9+
import pytest
10+
import uvicorn
11+
from inline_snapshot import snapshot
12+
from starlette.applications import Starlette
13+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
14+
from starlette.requests import Request
15+
from starlette.responses import Response
16+
17+
import mcp.types as types
18+
from mcp import Client
19+
from mcp.client.sse import sse_client
20+
from mcp.client.streamable_http import streamable_http_client
21+
from mcp.server import MCPServer
22+
from tests.test_helpers import wait_for_server
23+
24+
TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial")
25+
26+
27+
@contextmanager
28+
def set_test_contextvar(value: str) -> Iterator[None]:
29+
token = TEST_CONTEXTVAR.set(value)
30+
try:
31+
yield
32+
finally:
33+
TEST_CONTEXTVAR.reset(token)
34+
35+
36+
# Sends header CLIENT_HEADER with a configured value
37+
class SendClientHeaderTransport(httpx.AsyncHTTPTransport):
38+
def __init__(self) -> None:
39+
super().__init__()
40+
self.client_header_value: str = "initial"
41+
42+
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
43+
request.headers["CLIENT_HEADER"] = self.client_header_value
44+
return await super().handle_async_request(request)
45+
46+
47+
# Intercepts the httpx call to capture the contextvar's value
48+
class ContextCapturingTransport(httpx.AsyncHTTPTransport):
49+
def __init__(self):
50+
super().__init__()
51+
self.captured_context_var: str | None = None
52+
53+
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
54+
self.captured_context_var = TEST_CONTEXTVAR.get()
55+
return await super().handle_async_request(request)
56+
57+
58+
def create_server() -> MCPServer:
59+
mcp = MCPServer("test_server")
60+
61+
# tool that returns the value of TEST_CONTEXT_VAR.
62+
@mcp.tool()
63+
async def my_tool() -> str:
64+
return TEST_CONTEXTVAR.get()
65+
66+
return mcp
67+
68+
69+
@pytest.fixture
70+
def server_port() -> int:
71+
with socket.socket() as s:
72+
s.bind(("127.0.0.1", 0))
73+
return s.getsockname()[1]
74+
75+
76+
def run_server(transport: Literal["sse", "streamable_http"], port: int): # pragma: no cover
77+
class ContextVarMiddleware(BaseHTTPMiddleware): # pragma: lax no cover
78+
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
79+
actual_value = request.headers.get("CLIENT_HEADER")
80+
with set_test_contextvar(f"from middleware CLIENT_HEADER={actual_value}"):
81+
return await call_next(request)
82+
83+
server = create_server()
84+
85+
match transport:
86+
case "sse":
87+
app = server.sse_app(host="127.0.0.1")
88+
case "streamable_http":
89+
app = server.streamable_http_app(host="127.0.0.1")
90+
case _:
91+
assert_never(transport)
92+
93+
app.add_middleware(ContextVarMiddleware)
94+
95+
uvicorn.run(app, host="127.0.0.1", port=port, log_level="error")
96+
97+
98+
@contextmanager
99+
def start_server_process(transport: Literal["sse", "streamable_http"], port: int):
100+
"""Start server in a separate process."""
101+
process = multiprocessing.Process(target=run_server, args=(transport, port))
102+
103+
process.start()
104+
try:
105+
wait_for_server(port)
106+
yield process
107+
finally:
108+
process.terminate()
109+
process.join()
110+
111+
112+
@pytest.mark.anyio
113+
async def test_memory_transport_client_to_server():
114+
async with Client(create_server()) as client:
115+
with set_test_contextvar("client_value"):
116+
result = await client.call_tool(name="my_tool")
117+
118+
assert isinstance(result, types.CallToolResult)
119+
assert result.content == snapshot([types.TextContent(text="client_value")])
120+
121+
122+
@pytest.mark.anyio
123+
async def test_streamable_http_asgi_to_mcpserver(server_port: int):
124+
with start_server_process("streamable_http", server_port):
125+
async with (
126+
SendClientHeaderTransport() as transport,
127+
httpx.AsyncClient(transport=transport) as http_client,
128+
Client(streamable_http_client(f"http://127.0.0.1:{server_port}/mcp", http_client=http_client)) as client,
129+
):
130+
transport.client_header_value = "expected_value"
131+
result = await client.call_tool("my_tool")
132+
assert result.content == snapshot([types.TextContent(text="from middleware CLIENT_HEADER=expected_value")])
133+
134+
135+
@pytest.mark.anyio
136+
async def test_streamable_http_mcpclient_to_httpx(server_port: int):
137+
with start_server_process("streamable_http", server_port):
138+
async with (
139+
ContextCapturingTransport() as transport,
140+
httpx.AsyncClient(transport=transport) as http_client,
141+
Client(streamable_http_client(f"http://127.0.0.1:{server_port}/mcp", http_client=http_client)) as client,
142+
):
143+
with set_test_contextvar("client_value_list"):
144+
await client.list_tools()
145+
assert transport.captured_context_var == snapshot("client_value_list")
146+
147+
with set_test_contextvar("client_value_call_tool"):
148+
await client.call_tool("my_tool")
149+
assert transport.captured_context_var == snapshot("client_value_call_tool")
150+
151+
152+
@pytest.mark.anyio
153+
async def test_sse_asgi_to_mcpserver(server_port: int):
154+
transport = SendClientHeaderTransport()
155+
156+
def client_factory(
157+
headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None
158+
) -> httpx.AsyncClient:
159+
return httpx.AsyncClient(transport=transport, headers=headers, timeout=timeout, auth=auth)
160+
161+
with start_server_process("sse", server_port):
162+
async with Client(
163+
sse_client(f"http://127.0.0.1:{server_port}/sse", httpx_client_factory=client_factory)
164+
) as client:
165+
transport.client_header_value = "expected_value"
166+
result = await client.call_tool("my_tool")
167+
assert result.content == snapshot([types.TextContent(text="from middleware CLIENT_HEADER=expected_value")])
168+
169+
170+
@pytest.mark.anyio
171+
async def test_sse_mcpclient_to_httpx(server_port: int):
172+
transport = ContextCapturingTransport()
173+
174+
def client_factory(
175+
headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None
176+
) -> httpx.AsyncClient:
177+
return httpx.AsyncClient(transport=transport, headers=headers, timeout=timeout, auth=auth)
178+
179+
with start_server_process("sse", server_port):
180+
async with Client(
181+
sse_client(f"http://127.0.0.1:{server_port}/sse", httpx_client_factory=client_factory)
182+
) as client:
183+
with set_test_contextvar("client_value_list"):
184+
await client.list_tools()
185+
assert transport.captured_context_var == snapshot("client_value_list")
186+
187+
with set_test_contextvar("client_value_call_tool"):
188+
await client.call_tool("my_tool")
189+
assert transport.captured_context_var == snapshot("client_value_call_tool")

0 commit comments

Comments
 (0)