Skip to content

Commit 0b2fd83

Browse files
committed
fix: refresh auth context for stateful HTTP requests
1 parent b33c811 commit 0b2fd83

File tree

2 files changed

+110
-4
lines changed

2 files changed

+110
-4
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def main():
4141
from collections.abc import AsyncIterator, Awaitable, Callable
4242
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
4343
from importlib.metadata import version as importlib_version
44-
from typing import Any, Generic
44+
from typing import Any, Generic, cast
4545

4646
import anyio
4747
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -52,8 +52,8 @@ async def main():
5252
from typing_extensions import TypeVar
5353

5454
from mcp import types
55-
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
56-
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
55+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, auth_context_var
56+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, BearerAuthBackend, RequireAuthMiddleware
5757
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
5858
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
5959
from mcp.server.auth.settings import AuthSettings
@@ -471,7 +471,15 @@ async def _handle_request(
471471
close_sse_stream=close_sse_stream_cb,
472472
close_standalone_sse_stream=close_standalone_sse_stream_cb,
473473
)
474-
response = await handler(ctx, req.params)
474+
request_scope = cast(dict[str, object] | None, getattr(request_data, "scope", None))
475+
request_user = request_scope.get("user") if request_scope is not None else None
476+
auth_context_token = auth_context_var.set(
477+
request_user if isinstance(request_user, AuthenticatedUser) else None
478+
)
479+
try:
480+
response = await handler(ctx, req.params)
481+
finally:
482+
auth_context_var.reset(auth_context_token)
475483
except MCPError as err:
476484
response = err.error
477485
except anyio.get_cancelled_exc_class():
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Regression test for issue #2208.
2+
3+
In stateful streamable HTTP sessions, get_access_token() must reflect the
4+
Authorization header from the current request, not the one that created the
5+
session's background receive task.
6+
"""
7+
8+
import time
9+
10+
import httpx
11+
import pytest
12+
from pydantic import AnyHttpUrl
13+
14+
from mcp.client.session import ClientSession
15+
from mcp.client.streamable_http import streamable_http_client
16+
from mcp.server import Server, ServerRequestContext
17+
from mcp.server.auth.middleware.auth_context import get_access_token
18+
from mcp.server.auth.provider import AccessToken
19+
from mcp.server.auth.settings import AuthSettings
20+
from mcp.types import CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, TextContent, Tool
21+
22+
23+
class EchoTokenVerifier:
24+
"""Accept any bearer token and expose it in the authenticated user."""
25+
26+
async def verify_token(self, token: str) -> AccessToken | None:
27+
return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600)
28+
29+
30+
class MutableBearerAuth(httpx.Auth):
31+
"""Update the bearer token between requests without rebuilding the client."""
32+
33+
def __init__(self, token: str) -> None:
34+
self.token = token
35+
36+
def auth_flow(self, request: httpx.Request):
37+
request.headers["Authorization"] = f"Bearer {self.token}"
38+
yield request
39+
40+
41+
async def handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
42+
access_token = get_access_token()
43+
token = access_token.token if access_token else "<none>"
44+
return CallToolResult(content=[TextContent(type="text", text=token)])
45+
46+
47+
async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
48+
return ListToolsResult(
49+
tools=[
50+
Tool(
51+
name="whoami",
52+
input_schema={"type": "object", "properties": {}},
53+
)
54+
]
55+
)
56+
57+
58+
@pytest.mark.anyio
59+
async def test_get_access_token_uses_current_request_in_stateful_streamable_http_session() -> None:
60+
server = Server(
61+
"auth-test-server",
62+
on_call_tool=handle_whoami,
63+
on_list_tools=handle_list_tools,
64+
)
65+
app = server.streamable_http_app(
66+
host="testserver",
67+
auth=AuthSettings(
68+
issuer_url=AnyHttpUrl("https://auth.example.com"),
69+
resource_server_url=AnyHttpUrl("https://testserver/mcp"),
70+
),
71+
token_verifier=EchoTokenVerifier(),
72+
)
73+
auth = MutableBearerAuth("token-A")
74+
75+
async with (
76+
app.router.lifespan_context(app),
77+
httpx.ASGITransport(app) as transport,
78+
httpx.AsyncClient(
79+
transport=transport,
80+
base_url="http://testserver",
81+
auth=auth,
82+
follow_redirects=True,
83+
timeout=httpx.Timeout(30.0, read=30.0),
84+
) as http_client,
85+
streamable_http_client("http://testserver/mcp", http_client=http_client) as (read_stream, write_stream),
86+
ClientSession(read_stream, write_stream) as session,
87+
):
88+
await session.initialize()
89+
90+
first_response = await session.call_tool("whoami", {})
91+
assert isinstance(first_response.content[0], TextContent)
92+
assert first_response.content[0].text == "token-A"
93+
94+
auth.token = "token-B"
95+
96+
second_response = await session.call_tool("whoami", {})
97+
assert isinstance(second_response.content[0], TextContent)
98+
assert second_response.content[0].text == "token-B"

0 commit comments

Comments
 (0)