-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Expand file tree
/
Copy pathtest_get_access_token_streamable_http.py
More file actions
97 lines (77 loc) · 3.5 KB
/
test_get_access_token_streamable_http.py
File metadata and controls
97 lines (77 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import time
import httpx
import pytest
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.routing import Mount
from mcp import Client
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server, ServerRequestContext
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
from mcp.server.auth.provider import AccessToken
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.types import (
CallToolRequestParams,
CallToolResult,
ListToolsResult,
PaginatedRequestParams,
TextContent,
Tool,
)
class _EchoTokenVerifier:
    """Token verifier stub that treats every bearer token it is handed as valid.

    The presented token string is mirrored into both ``token`` and ``client_id`` of the
    returned ``AccessToken``, so a test can tell which request's credentials it observes.
    """

    async def verify_token(self, token: str) -> AccessToken | None:
        """Return an ``AccessToken`` echoing ``token``, with no scopes, expiring 3600 seconds from now."""
        expiry = int(time.time()) + 3600
        return AccessToken(
            token=token,
            client_id=token,
            scopes=[],
            expires_at=expiry,
        )
async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
    """Tool handler that reports whatever ``get_access_token()`` returns at call time.

    The result holds a single text item: the token string, or ``"<none>"`` when
    ``get_access_token()`` yields a falsy value.
    """
    current = get_access_token()
    if current:
        reported = current.token
    else:
        reported = "<none>"
    item = TextContent(type="text", text=reported)
    return CallToolResult(content=[item])
async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
    """Advertise the single ``whoami`` tool, whose input schema is an object with no properties."""
    whoami_tool = Tool(
        name="whoami",
        input_schema={"type": "object", "properties": {}},
    )
    return ListToolsResult(tools=[whoami_tool])
class _MutableBearerAuth(httpx.Auth):
    """httpx auth hook that sends ``self.token`` as a bearer credential.

    ``token`` is a plain attribute that is read each time a request passes through
    ``auth_flow``. When it is ``None`` the request is yielded without this class
    setting an ``Authorization`` header.
    """

    def __init__(self, token: str | None) -> None:
        """Store the bearer token to attach, or ``None`` to attach nothing."""
        self.token = token

    def auth_flow(self, request: httpx.Request):
        """Set the ``Authorization`` header when a token is present, then yield the request."""
        bearer = self.token
        if bearer is not None:
            request.headers["Authorization"] = f"Bearer {bearer}"
        yield request
async def _call_whoami(asgi_app: Starlette, host: str, token: str | None) -> str:
    """Connect a new MCP client to ``asgi_app`` in-process and return what ``whoami`` reports.

    Args:
        asgi_app: Application the requests are dispatched to through ``httpx.ASGITransport``.
        host: Host name used to build both the base URL and the ``/mcp`` endpoint URL.
        token: Bearer token to send, or ``None`` to send no ``Authorization`` header.

    Returns:
        The text of the first content item in the ``whoami`` tool result.
    """
    base_url = f"http://{host}"
    endpoint = f"http://{host}/mcp"
    bearer_auth = _MutableBearerAuth(token)
    async with (
        httpx.ASGITransport(asgi_app) as asgi_transport,
        httpx.AsyncClient(
            transport=asgi_transport,
            base_url=base_url,
            auth=bearer_auth,
            timeout=httpx.Timeout(30, read=30),
            follow_redirects=True,
        ) as http_client,
    ):
        mcp_transport = streamable_http_client(endpoint, http_client=http_client)
        async with Client(mcp_transport) as client:  # pragma: no branch
            tool_result = await client.call_tool("whoami", {})
            first_item = tool_result.content[0]
            assert isinstance(first_item, TextContent)
            return first_item.text
@pytest.mark.anyio
async def test_get_access_token_reflects_current_request_in_stateful_session() -> None:
    """``get_access_token()`` inside a tool handler must match the calling request's credentials.

    Three ``whoami`` calls run in sequence against one stateful session manager: two with
    different bearer tokens and one with none. Each must report its own token.
    """
    host = "testserver"

    mcp_server = Server(
        "auth-test-server",
        on_call_tool=_handle_whoami,
        on_list_tools=_handle_list_tools,
    )
    session_manager = StreamableHTTPSessionManager(app=mcp_server, stateless=False)

    auth_stack = [
        Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())),
        Middleware(AuthContextMiddleware),
    ]
    mcp_mount = Mount("/mcp", app=session_manager.handle_request)
    asgi_app = Starlette(
        routes=[mcp_mount],
        middleware=auth_stack,
        lifespan=lambda app: session_manager.run(),
    )

    # (token sent, text the whoami tool is expected to report), exercised in this order.
    cases = [
        ("token-A", "token-A"),
        ("token-B", "token-B"),
        (None, "<none>"),
    ]

    async with asgi_app.router.lifespan_context(asgi_app):
        for sent_token, expected_text in cases:
            assert await _call_whoami(asgi_app, host, sent_token) == expected_text