Skip to content

Commit 00174fc

Browse files
committed
fix auth context reset for streamable HTTP
1 parent 190f101 commit 00174fc

2 files changed

Lines changed: 27 additions & 32 deletions

File tree

src/mcp/server/auth/middleware/auth_context.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ def push_auth_context_from_request(request: Request | None) -> Token[Authenticat
3838
user = getattr(request, "user", None)
3939
except AssertionError:
4040
user = None
41-
if isinstance(user, AuthenticatedUser):
42-
return auth_context_var.set(user)
43-
return None
41+
return auth_context_var.set(user if isinstance(user, AuthenticatedUser) else None)
4442

4543

4644
def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None:

tests/server/auth/test_get_access_token_streamable_http.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
1515
from mcp.server.auth.provider import AccessToken
1616
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
17-
from mcp.server.transport_security import TransportSecuritySettings
1817
from mcp.types import (
1918
CallToolRequestParams,
2019
CallToolResult,
@@ -43,14 +42,34 @@ async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequest
4342

4443

4544
class _MutableBearerAuth(httpx.Auth):
46-
def __init__(self, token: str) -> None:
45+
def __init__(self, token: str | None) -> None:
4746
self.token = token
4847

4948
def auth_flow(self, request: httpx.Request):
50-
request.headers["Authorization"] = f"Bearer {self.token}"
49+
if self.token is not None:
50+
request.headers["Authorization"] = f"Bearer {self.token}"
5151
yield request
5252

5353

54+
async def _call_whoami(asgi_app: Starlette, host: str, token: str | None) -> str:
55+
auth = _MutableBearerAuth(token)
56+
async with (
57+
httpx.ASGITransport(asgi_app) as transport,
58+
httpx.AsyncClient(
59+
transport=transport,
60+
base_url=f"http://{host}",
61+
auth=auth,
62+
timeout=httpx.Timeout(30, read=30),
63+
follow_redirects=True,
64+
) as http_client,
65+
):
66+
transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client)
67+
async with Client(transport_ctx) as client: # pragma: no branch
68+
result = await client.call_tool("whoami", {})
69+
assert isinstance(result.content[0], TextContent)
70+
return result.content[0].text
71+
72+
5473
@pytest.mark.anyio
5574
async def test_get_access_token_reflects_current_request_in_stateful_session() -> None:
5675
host = "testserver"
@@ -61,11 +80,7 @@ async def test_get_access_token_reflects_current_request_in_stateful_session() -
6180
on_list_tools=_handle_list_tools,
6281
)
6382

64-
security = TransportSecuritySettings(
65-
allowed_hosts=[host, f"{host}:*"],
66-
allowed_origins=[f"http://{host}:*"],
67-
)
68-
session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False)
83+
session_manager = StreamableHTTPSessionManager(app=server, stateless=False)
6984

7085
asgi_app = Starlette(
7186
routes=[Mount("/mcp", app=session_manager.handle_request)],
@@ -76,25 +91,7 @@ async def test_get_access_token_reflects_current_request_in_stateful_session() -
7691
lifespan=lambda app: session_manager.run(),
7792
)
7893

79-
auth = _MutableBearerAuth("token-A")
8094
async with asgi_app.router.lifespan_context(asgi_app):
81-
async with (
82-
httpx.ASGITransport(asgi_app) as transport,
83-
httpx.AsyncClient(
84-
transport=transport,
85-
base_url=f"http://{host}",
86-
auth=auth,
87-
timeout=httpx.Timeout(30, read=30),
88-
follow_redirects=True,
89-
) as http_client,
90-
):
91-
transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client)
92-
async with Client(transport_ctx) as client:
93-
r1 = await client.call_tool("whoami", {})
94-
assert isinstance(r1.content[0], TextContent)
95-
assert r1.content[0].text == "token-A"
96-
97-
auth.token = "token-B"
98-
r2 = await client.call_tool("whoami", {})
99-
assert isinstance(r2.content[0], TextContent)
100-
assert r2.content[0].text == "token-B"
95+
assert await _call_whoami(asgi_app, host, "token-A") == "token-A"
96+
assert await _call_whoami(asgi_app, host, "token-B") == "token-B"
97+
assert await _call_whoami(asgi_app, host, None) == "<none>"

0 commit comments

Comments
 (0)