Skip to content

Commit 5ea2360

Browse files
committed
fix(auth): make get_access_token per-request in stateful sessions
1 parent 5d82649 commit 5ea2360

3 files changed

Lines changed: 129 additions & 1 deletion

File tree

src/mcp/server/auth/middleware/auth_context.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import contextvars
22

3+
from contextvars import Token
4+
5+
from starlette.requests import Request
36
from starlette.types import ASGIApp, Receive, Scope, Send
47

58
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
@@ -20,6 +23,26 @@ def get_access_token() -> AccessToken | None:
2023
return auth_user.access_token if auth_user else None
2124

2225

26+
def _push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None:
27+
"""Set auth context for the current task from an incoming request.
28+
29+
This is primarily used by server transports where request handlers may run
30+
in background tasks that are not part of the original ASGI request task.
31+
"""
32+
if request is None:
33+
return None
34+
user = getattr(request, "user", None)
35+
if isinstance(user, AuthenticatedUser):
36+
return auth_context_var.set(user)
37+
return None
38+
39+
40+
def _pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None:
41+
if token is None:
42+
return
43+
auth_context_var.reset(token)
44+
45+
2346
class AuthContextMiddleware:
2447
"""Middleware that extracts the authenticated user from the request
2548
and sets it in a contextvar for easy access throughout the request lifecycle.

src/mcp/server/runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pydantic import BaseModel, ValidationError
2727
from typing_extensions import TypeVar
2828

29+
from mcp.server.auth.middleware.auth_context import _pop_auth_context, _push_auth_context_from_request
2930
from mcp.server.connection import Connection
3031
from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext
3132
from mcp.server.models import InitializationOptions
@@ -259,7 +260,11 @@ async def _inner() -> HandlerResult:
259260
return result
260261

261262
call = self._compose_server_middleware(ctx, method, params, _inner)
262-
result = _dump_result(await call())
263+
auth_token = _push_auth_context_from_request(ctx.request)
264+
try:
265+
result = _dump_result(await call())
266+
finally:
267+
_pop_auth_context(auth_token)
263268
if method == "initialize":
264269
# Commit only on chain success, so a middleware veto leaves no state.
265270
# Race-free: the read loop is parked until this call returns.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import time
2+
3+
import httpx
4+
import pytest
5+
from starlette.applications import Starlette
6+
from starlette.middleware import Middleware
7+
from starlette.middleware.authentication import AuthenticationMiddleware
8+
from starlette.routing import Mount
9+
10+
from mcp import Client
11+
from mcp.client.streamable_http import streamable_http_client
12+
from mcp.server import Server, ServerRequestContext
13+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
14+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
15+
from mcp.server.auth.provider import AccessToken
16+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
17+
from mcp.server.transport_security import TransportSecuritySettings
18+
from mcp.types import (
19+
CallToolRequestParams,
20+
CallToolResult,
21+
ListToolsResult,
22+
PaginatedRequestParams,
23+
TextContent,
24+
Tool,
25+
)
26+
27+
28+
class _EchoTokenVerifier:
29+
"""Accepts any bearer token and echoes it back as the verified AccessToken."""
30+
31+
async def verify_token(self, token: str) -> AccessToken | None:
32+
return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600)
33+
34+
35+
async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
36+
access = get_access_token()
37+
text = access.token if access else "<none>"
38+
return CallToolResult(content=[TextContent(type="text", text=text)])
39+
40+
41+
async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
42+
return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})])
43+
44+
45+
class _MutableBearerAuth(httpx.Auth):
46+
def __init__(self, token: str) -> None:
47+
self.token = token
48+
49+
def auth_flow(self, request: httpx.Request):
50+
request.headers["Authorization"] = f"Bearer {self.token}"
51+
yield request
52+
53+
54+
@pytest.mark.anyio
55+
async def test_get_access_token_reflects_current_request_in_stateful_session() -> None:
56+
host = "testserver"
57+
58+
server = Server(
59+
"auth-test-server",
60+
on_call_tool=_handle_whoami,
61+
on_list_tools=_handle_list_tools,
62+
)
63+
64+
security = TransportSecuritySettings(
65+
allowed_hosts=[host, f"{host}:*"],
66+
allowed_origins=[f"http://{host}:*"],
67+
)
68+
session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False)
69+
70+
asgi_app = Starlette(
71+
routes=[Mount("/mcp", app=session_manager.handle_request)],
72+
middleware=[
73+
Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())),
74+
Middleware(AuthContextMiddleware),
75+
],
76+
lifespan=lambda app: session_manager.run(),
77+
)
78+
79+
auth = _MutableBearerAuth("token-A")
80+
async with asgi_app.router.lifespan_context(asgi_app):
81+
async with (
82+
httpx.ASGITransport(asgi_app) as transport,
83+
httpx.AsyncClient(
84+
transport=transport,
85+
base_url=f"http://{host}",
86+
auth=auth,
87+
timeout=httpx.Timeout(30, read=30),
88+
follow_redirects=True,
89+
) as http_client,
90+
):
91+
transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client)
92+
async with Client(transport_ctx) as client:
93+
r1 = await client.call_tool("whoami", {})
94+
assert isinstance(r1.content[0], TextContent)
95+
assert r1.content[0].text == "token-A"
96+
97+
auth.token = "token-B"
98+
r2 = await client.call_tool("whoami", {})
99+
assert isinstance(r2.content[0], TextContent)
100+
assert r2.content[0].text == "token-B"

0 commit comments

Comments
 (0)