2121from httpx_sse import ServerSentEvent
2222from pydantic import AnyUrl
2323from starlette .applications import Starlette
24+ from starlette .middleware import Middleware
25+ from starlette .middleware .authentication import AuthenticationMiddleware
2426from starlette .requests import Request
2527from starlette .routing import Mount
2628
3234 streamablehttp_client , # pyright: ignore[reportDeprecated]
3335)
3436from mcp .server import Server
37+ from mcp .server .auth .middleware .auth_context import AuthContextMiddleware , get_access_token
38+ from mcp .server .auth .middleware .bearer_auth import BearerAuthBackend
39+ from mcp .server .auth .provider import AccessToken , TokenVerifier
3540from mcp .server .streamable_http import (
3641 MCP_PROTOCOL_VERSION_HEADER ,
3742 MCP_SESSION_ID_HEADER ,
@@ -1520,6 +1525,71 @@ def run_context_aware_server(port: int): # pragma: no cover
15201525 server_instance .run ()
15211526
15221527
1528+ class AuthTokenServerTest (Server ): # pragma: no cover
1529+ def __init__ (self ):
1530+ super ().__init__ ("AuthTokenServer" )
1531+
1532+ @self .list_tools ()
1533+ async def handle_list_tools () -> list [Tool ]:
1534+ return [
1535+ Tool (
1536+ name = "echo_access_token" ,
1537+ description = "Return the current access token" ,
1538+ inputSchema = {"type" : "object" , "properties" : {}},
1539+ )
1540+ ]
1541+
1542+ @self .call_tool ()
1543+ async def handle_call_tool (name : str , _args : dict [str , Any ]) -> list [TextContent ]:
1544+ assert name == "echo_access_token"
1545+ access_token = get_access_token ()
1546+ assert access_token is not None
1547+ return [TextContent (type = "text" , text = access_token .token )]
1548+
1549+
1550+ def run_auth_token_server (port : int ) -> None : # pragma: no cover
1551+ """Run the auth token test server."""
1552+ server = AuthTokenServerTest ()
1553+
1554+ class AcceptAllTokenVerifier (TokenVerifier ):
1555+ async def verify_token (self , token : str ) -> AccessToken | None :
1556+ return AccessToken (
1557+ token = token ,
1558+ client_id = "test-client" ,
1559+ scopes = ["test" ],
1560+ )
1561+
1562+ token_verifier = AcceptAllTokenVerifier ()
1563+
1564+ session_manager = StreamableHTTPSessionManager (
1565+ app = server ,
1566+ event_store = None ,
1567+ json_response = False ,
1568+ )
1569+
1570+ middleware = [
1571+ Middleware (AuthenticationMiddleware , backend = BearerAuthBackend (token_verifier )),
1572+ Middleware (AuthContextMiddleware ),
1573+ ]
1574+
1575+ app = Starlette (
1576+ debug = True ,
1577+ routes = [Mount ("/mcp" , app = session_manager .handle_request )],
1578+ middleware = middleware ,
1579+ lifespan = lambda app : session_manager .run (),
1580+ )
1581+
1582+ server_instance = uvicorn .Server (
1583+ config = uvicorn .Config (
1584+ app = app ,
1585+ host = "127.0.0.1" ,
1586+ port = port ,
1587+ log_level = "error" ,
1588+ )
1589+ )
1590+ server_instance .run ()
1591+
1592+
15231593@pytest .fixture
15241594def context_aware_server (basic_server_port : int ) -> Generator [None , None , None ]:
15251595 """Start the context-aware server in a separate process."""
@@ -1537,6 +1607,22 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
15371607 print ("Context-aware server process failed to terminate" )
15381608
15391609
1610+ @pytest .fixture
1611+ def auth_token_server (basic_server_port : int ) -> Generator [None , None , None ]:
1612+ """Start the auth token server in a separate process."""
1613+ proc = multiprocessing .Process (target = run_auth_token_server , args = (basic_server_port ,), daemon = True )
1614+ proc .start ()
1615+
1616+ wait_for_server (basic_server_port )
1617+
1618+ yield
1619+
1620+ proc .kill ()
1621+ proc .join (timeout = 2 )
1622+ if proc .is_alive (): # pragma: no cover
1623+ print ("Auth token server process failed to terminate" )
1624+
1625+
15401626@pytest .mark .anyio
15411627async def test_streamablehttp_request_context_propagation (context_aware_server : None , basic_server_url : str ) -> None :
15421628 """Test that request context is properly propagated through StreamableHTTP."""
@@ -1571,6 +1657,34 @@ async def test_streamablehttp_request_context_propagation(context_aware_server:
15711657 assert headers_data .get ("x-trace-id" ) == "trace-123"
15721658
15731659
1660+ @pytest .mark .anyio
1661+ async def test_streamablehttp_refreshes_access_token (auth_token_server : None , basic_server_url : str ) -> None :
1662+ """Ensure refreshed bearer tokens are used for subsequent requests."""
1663+ token_a = "token-a"
1664+ token_b = "token-b"
1665+
1666+ async with create_mcp_http_client (headers = {"Authorization" : f"Bearer { token_a } " }) as httpx_client :
1667+ async with streamable_http_client (f"{ basic_server_url } /mcp" , http_client = httpx_client ) as (
1668+ read_stream ,
1669+ write_stream ,
1670+ _ ,
1671+ ):
1672+ async with ClientSession (read_stream , write_stream ) as session :
1673+ result = await session .initialize ()
1674+ assert isinstance (result , InitializeResult )
1675+
1676+ tool_result = await session .call_tool ("echo_access_token" , {})
1677+ assert len (tool_result .content ) == 1
1678+ assert isinstance (tool_result .content [0 ], TextContent )
1679+ assert tool_result .content [0 ].text == token_a
1680+
1681+ httpx_client .headers ["Authorization" ] = f"Bearer { token_b } "
1682+ tool_result = await session .call_tool ("echo_access_token" , {})
1683+ assert len (tool_result .content ) == 1
1684+ assert isinstance (tool_result .content [0 ], TextContent )
1685+ assert tool_result .content [0 ].text == token_b
1686+
1687+
15741688@pytest .mark .anyio
15751689async def test_streamablehttp_request_context_isolation (context_aware_server : None , basic_server_url : str ) -> None :
15761690 """Test that request contexts are isolated between StreamableHTTP clients."""
0 commit comments