-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_auth.py
More file actions
133 lines (116 loc) · 5.69 KB
/
test_auth.py
File metadata and controls
133 lines (116 loc) · 5.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Tests for OAuth 2.0 shared code."""
import json
from pydantic import AnyHttpUrl
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
def test_oauth():
    """Plain RFC 8414 authorization-server metadata validates without raising."""
    payload = {
        "issuer": "https://example.com",
        "authorization_endpoint": "https://example.com/oauth2/authorize",
        "token_endpoint": "https://example.com/oauth2/token",
        "scopes_supported": ["read", "write"],
        "response_types_supported": ["code", "token"],
        "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
    }
    # model_validate raises on an invalid document, so simply reaching the end
    # of this test is the pass condition.
    OAuthMetadata.model_validate(payload)
def test_oidc():
    """An OpenID Connect discovery document validates as OAuth metadata without raising.

    The payload mixes RFC 8414 fields with OIDC discovery keys such as
    ``userinfo_endpoint``, ``end_session_endpoint`` and
    ``subject_types_supported``. Whether the model stores or ignores those
    extra keys depends on its configuration; this test only requires that
    validation succeeds.
    """
    oidc_document = {
        "issuer": "https://example.com",
        "authorization_endpoint": "https://example.com/oauth2/authorize",
        "token_endpoint": "https://example.com/oauth2/token",
        "end_session_endpoint": "https://example.com/logout",
        "id_token_signing_alg_values_supported": ["RS256"],
        "jwks_uri": "https://example.com/.well-known/jwks.json",
        "response_types_supported": ["code", "token"],
        "revocation_endpoint": "https://example.com/oauth2/revoke",
        "scopes_supported": ["openid", "read", "write"],
        "subject_types_supported": ["public"],
        "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
        "userinfo_endpoint": "https://example.com/oauth2/userInfo",
    }
    OAuthMetadata.model_validate(oidc_document)
def test_oauth_with_jarm():
    """Parse OAuth metadata that advertises JARM response modes and keep them intact.

    JARM (JWT Secured Authorization Response Mode) adds the ``query.jwt``,
    ``fragment.jwt``, ``form_post.jwt`` and ``jwt`` modes on top of the classic
    ``query`` / ``fragment`` / ``form_post`` ones. Validation must accept all of
    them, and the parsed model must preserve the list unchanged. Otherwise a
    schema change that silently drops or filters the field would go unnoticed.
    """
    response_modes = [
        "query",
        "fragment",
        "form_post",
        "query.jwt",
        "fragment.jwt",
        "form_post.jwt",
        "jwt",
    ]
    metadata = OAuthMetadata.model_validate(
        {
            "issuer": "https://example.com",
            "authorization_endpoint": "https://example.com/oauth2/authorize",
            "token_endpoint": "https://example.com/oauth2/token",
            "scopes_supported": ["read", "write"],
            "response_types_supported": ["code", "token"],
            "response_modes_supported": response_modes,
            "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
        }
    )
    # Go beyond "does not raise": every JARM mode must survive parsing verbatim.
    assert metadata.response_modes_supported == response_modes
class TestIssuerTrailingSlash:
    """Tests for issue #1919: trailing slash in issuer URL.

    RFC 8414 examples show issuer URLs without trailing slashes, and some
    OAuth clients require an exact match between the discovery URL and the
    returned issuer. Pydantic's AnyHttpUrl appends a trailing slash (e.g. to a
    bare host), so the models strip it again during serialization.
    """

    @staticmethod
    def _json_roundtrip(model):
        """Serialize *model* with ``model_dump_json`` and parse it back into a dict."""
        return json.loads(model.model_dump_json())

    @staticmethod
    def _oauth_metadata(issuer):
        """Build OAuthMetadata for *issuer* with fixed example endpoints."""
        return OAuthMetadata(
            issuer=AnyHttpUrl(issuer),
            authorization_endpoint=AnyHttpUrl("https://example.com/oauth2/authorize"),
            token_endpoint=AnyHttpUrl("https://example.com/oauth2/token"),
        )

    def test_oauth_metadata_issuer_no_trailing_slash_in_json(self):
        """A bare-host issuer serializes without the slash AnyHttpUrl adds."""
        serialized = self._json_roundtrip(self._oauth_metadata("https://example.com"))
        issuer = serialized["issuer"]
        assert issuer == "https://example.com"
        assert not issuer.endswith("/")

    def test_oauth_metadata_issuer_with_path_preserves_path(self):
        """An issuer with a path keeps the path untouched when serialized."""
        serialized = self._json_roundtrip(self._oauth_metadata("https://example.com/auth"))
        issuer = serialized["issuer"]
        assert issuer == "https://example.com/auth"
        assert not issuer.endswith("/")

    def test_oauth_metadata_issuer_with_path_and_trailing_slash(self):
        """An explicit trailing slash after a path is removed, and only that slash."""
        serialized = self._json_roundtrip(self._oauth_metadata("https://example.com/auth/"))
        assert serialized["issuer"] == "https://example.com/auth"

    def test_protected_resource_metadata_no_trailing_slash(self):
        """ProtectedResourceMetadata.resource serializes without a trailing slash."""
        resource_metadata = ProtectedResourceMetadata(
            resource=AnyHttpUrl("https://example.com"),
            authorization_servers=[AnyHttpUrl("https://auth.example.com")],
        )
        resource = self._json_roundtrip(resource_metadata)["resource"]
        assert resource == "https://example.com"
        assert not resource.endswith("/")

    def test_protected_resource_metadata_auth_servers_no_trailing_slash(self):
        """Each entry of ProtectedResourceMetadata.authorization_servers drops its trailing slash."""
        resource_metadata = ProtectedResourceMetadata(
            resource=AnyHttpUrl("https://example.com"),
            authorization_servers=[
                AnyHttpUrl("https://auth1.example.com"),
                AnyHttpUrl("https://auth2.example.com/path"),
            ],
        )
        servers = self._json_roundtrip(resource_metadata)["authorization_servers"]
        assert servers == [
            "https://auth1.example.com",
            "https://auth2.example.com/path",
        ]
        assert not any(server.endswith("/") for server in servers)