Skip to content

Commit ae8d1c1

Browse files
committed
Added regression test
1 parent bac2789 commit ae8d1c1

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed

tests/shared/test_streamable_http.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from httpx_sse import ServerSentEvent
2222
from pydantic import AnyUrl
2323
from starlette.applications import Starlette
24+
from starlette.middleware import Middleware
25+
from starlette.middleware.authentication import AuthenticationMiddleware
2426
from starlette.requests import Request
2527
from starlette.routing import Mount
2628

@@ -32,6 +34,9 @@
3234
streamablehttp_client, # pyright: ignore[reportDeprecated]
3335
)
3436
from mcp.server import Server
37+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
38+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
39+
from mcp.server.auth.provider import AccessToken, TokenVerifier
3540
from mcp.server.streamable_http import (
3641
MCP_PROTOCOL_VERSION_HEADER,
3742
MCP_SESSION_ID_HEADER,
@@ -1520,6 +1525,71 @@ def run_context_aware_server(port: int): # pragma: no cover
15201525
server_instance.run()
15211526

15221527

1528+
class AuthTokenServerTest(Server): # pragma: no cover
1529+
def __init__(self):
1530+
super().__init__("AuthTokenServer")
1531+
1532+
@self.list_tools()
1533+
async def handle_list_tools() -> list[Tool]:
1534+
return [
1535+
Tool(
1536+
name="echo_access_token",
1537+
description="Return the current access token",
1538+
inputSchema={"type": "object", "properties": {}},
1539+
)
1540+
]
1541+
1542+
@self.call_tool()
1543+
async def handle_call_tool(name: str, _args: dict[str, Any]) -> list[TextContent]:
1544+
assert name == "echo_access_token"
1545+
access_token = get_access_token()
1546+
assert access_token is not None
1547+
return [TextContent(type="text", text=access_token.token)]
1548+
1549+
1550+
def run_auth_token_server(port: int) -> None: # pragma: no cover
1551+
"""Run the auth token test server."""
1552+
server = AuthTokenServerTest()
1553+
1554+
class AcceptAllTokenVerifier(TokenVerifier):
1555+
async def verify_token(self, token: str) -> AccessToken | None:
1556+
return AccessToken(
1557+
token=token,
1558+
client_id="test-client",
1559+
scopes=["test"],
1560+
)
1561+
1562+
token_verifier = AcceptAllTokenVerifier()
1563+
1564+
session_manager = StreamableHTTPSessionManager(
1565+
app=server,
1566+
event_store=None,
1567+
json_response=False,
1568+
)
1569+
1570+
middleware = [
1571+
Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)),
1572+
Middleware(AuthContextMiddleware),
1573+
]
1574+
1575+
app = Starlette(
1576+
debug=True,
1577+
routes=[Mount("/mcp", app=session_manager.handle_request)],
1578+
middleware=middleware,
1579+
lifespan=lambda app: session_manager.run(),
1580+
)
1581+
1582+
server_instance = uvicorn.Server(
1583+
config=uvicorn.Config(
1584+
app=app,
1585+
host="127.0.0.1",
1586+
port=port,
1587+
log_level="error",
1588+
)
1589+
)
1590+
server_instance.run()
1591+
1592+
15231593
@pytest.fixture
15241594
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
15251595
"""Start the context-aware server in a separate process."""
@@ -1537,6 +1607,22 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
15371607
print("Context-aware server process failed to terminate")
15381608

15391609

1610+
@pytest.fixture
1611+
def auth_token_server(basic_server_port: int) -> Generator[None, None, None]:
1612+
"""Start the auth token server in a separate process."""
1613+
proc = multiprocessing.Process(target=run_auth_token_server, args=(basic_server_port,), daemon=True)
1614+
proc.start()
1615+
1616+
wait_for_server(basic_server_port)
1617+
1618+
yield
1619+
1620+
proc.kill()
1621+
proc.join(timeout=2)
1622+
if proc.is_alive(): # pragma: no cover
1623+
print("Auth token server process failed to terminate")
1624+
1625+
15401626
@pytest.mark.anyio
15411627
async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
15421628
"""Test that request context is properly propagated through StreamableHTTP."""
@@ -1571,6 +1657,34 @@ async def test_streamablehttp_request_context_propagation(context_aware_server:
15711657
assert headers_data.get("x-trace-id") == "trace-123"
15721658

15731659

1660+
@pytest.mark.anyio
1661+
async def test_streamablehttp_refreshes_access_token(auth_token_server: None, basic_server_url: str) -> None:
1662+
"""Ensure refreshed bearer tokens are used for subsequent requests."""
1663+
token_a = "token-a"
1664+
token_b = "token-b"
1665+
1666+
async with create_mcp_http_client(headers={"Authorization": f"Bearer {token_a}"}) as httpx_client:
1667+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as (
1668+
read_stream,
1669+
write_stream,
1670+
_,
1671+
):
1672+
async with ClientSession(read_stream, write_stream) as session:
1673+
result = await session.initialize()
1674+
assert isinstance(result, InitializeResult)
1675+
1676+
tool_result = await session.call_tool("echo_access_token", {})
1677+
assert len(tool_result.content) == 1
1678+
assert isinstance(tool_result.content[0], TextContent)
1679+
assert tool_result.content[0].text == token_a
1680+
1681+
httpx_client.headers["Authorization"] = f"Bearer {token_b}"
1682+
tool_result = await session.call_tool("echo_access_token", {})
1683+
assert len(tool_result.content) == 1
1684+
assert isinstance(tool_result.content[0], TextContent)
1685+
assert tool_result.content[0].text == token_b
1686+
1687+
15741688
@pytest.mark.anyio
15751689
async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
15761690
"""Test that request contexts are isolated between StreamableHTTP clients."""

0 commit comments

Comments
 (0)