forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_auth_context.py
More file actions
141 lines (104 loc) · 4.16 KB
/
test_auth_context.py
File metadata and controls
141 lines (104 loc) · 4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Tests for the AuthContext middleware components."""
import time
import pytest
from starlette.types import Message, Receive, Scope, Send
from mcp.server.auth.middleware.auth_context import (
AuthContextMiddleware,
auth_context_var,
get_access_token,
)
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
class MockApp:
    """Stand-in ASGI application that records how it was invoked.

    Besides remembering the ``scope``/``receive``/``send`` triple it was
    called with, it stores the result of ``get_access_token()`` taken inside
    ``__call__`` so a test can inspect it after the call has returned.
    """

    def __init__(self):
        # Nothing has been observed until the app is actually awaited.
        self.called = False
        self.scope: Scope | None = None
        self.receive: Receive | None = None
        self.send: Send | None = None
        self.access_token_during_call: AccessToken | None = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Record the ASGI arguments and the access token visible right now."""
        self.called = True
        self.scope, self.receive, self.send = scope, receive, send
        # Snapshot whatever get_access_token() reports while this call is running.
        self.access_token_during_call = get_access_token()
@pytest.fixture
def valid_access_token() -> AccessToken:
    """Build an access token whose expiry lies one hour in the future."""
    one_hour_from_now = int(time.time()) + 3600
    return AccessToken(
        token="valid_token",
        client_id="test_client",
        scopes=["read", "write"],
        expires_at=one_hour_from_now,
        subject="user_123",
    )
@pytest.mark.anyio
async def test_auth_context_middleware_with_authenticated_user(valid_access_token: AccessToken):
    """With an authenticated user in scope, the token is visible only during the call."""

    # Placeholder ASGI callables; MockApp stores them without awaiting them.
    async def receive() -> Message:  # pragma: no cover
        return {"type": "http.request"}

    async def send(message: Message) -> None:  # pragma: no cover
        pass

    inner_app = MockApp()
    wrapped = AuthContextMiddleware(inner_app)

    # Scope carrying an authenticated user built from the fixture token.
    authenticated_user = AuthenticatedUser(valid_access_token)
    scope: Scope = {"type": "http", "user": authenticated_user}

    # Precondition: no auth context is set before the middleware runs.
    assert auth_context_var.get() is None
    assert get_access_token() is None

    await wrapped(scope, receive, send)

    # The inner app must have been invoked with the very same ASGI triple.
    assert inner_app.called
    assert inner_app.scope == scope
    assert inner_app.receive == receive
    assert inner_app.send == send

    # While the inner app was running, it saw the user's token.
    assert inner_app.access_token_during_call == valid_access_token

    # Postcondition: the context is cleared again once the middleware returns.
    assert auth_context_var.get() is None
    assert get_access_token() is None
@pytest.mark.anyio
async def test_auth_context_middleware_subject_preserved(valid_access_token: AccessToken):
    """The token returned by get_access_token() still carries its ``subject``."""

    # Placeholder ASGI callables; MockApp stores them without awaiting them.
    async def receive() -> Message:  # pragma: no cover
        return {"type": "http.request"}

    async def send(message: Message) -> None:  # pragma: no cover
        pass

    inner_app = MockApp()
    wrapped = AuthContextMiddleware(inner_app)
    scope: Scope = {"type": "http", "user": AuthenticatedUser(valid_access_token)}

    await wrapped(scope, receive, send)

    # The token observed inside the app must exist and expose the fixture's subject.
    seen_token = inner_app.access_token_during_call
    assert seen_token is not None
    assert seen_token.subject == "user_123"
@pytest.mark.anyio
async def test_auth_context_middleware_with_no_user():
    """Without a user in scope, no access token is exposed at any point."""

    # Placeholder ASGI callables; MockApp stores them without awaiting them.
    async def receive() -> Message:  # pragma: no cover
        return {"type": "http.request"}

    async def send(message: Message) -> None:  # pragma: no cover
        pass

    inner_app = MockApp()
    wrapped = AuthContextMiddleware(inner_app)

    # Deliberately omit the "user" key from the scope.
    scope: Scope = {"type": "http"}

    # Precondition: no auth context is set before the middleware runs.
    assert auth_context_var.get() is None
    assert get_access_token() is None

    await wrapped(scope, receive, send)

    # The inner app must have been invoked with the very same ASGI triple.
    assert inner_app.called
    assert inner_app.scope == scope
    assert inner_app.receive == receive
    assert inner_app.send == send

    # No token was visible from inside the app.
    assert inner_app.access_token_during_call is None

    # Postcondition: the context is still empty after the middleware returns.
    assert auth_context_var.get() is None
    assert get_access_token() is None