Skip to content

Commit 5db2ced

Browse files
committed
feat(oauth): support client_secret_basic
1 parent 0bcecff commit 5db2ced

4 files changed

Lines changed: 115 additions & 22 deletions

File tree

src/mcp/client/auth.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import secrets
1111
import string
1212
import time
13+
from base64 import b64encode
1314
from collections.abc import AsyncGenerator, Awaitable, Callable
1415
from typing import Protocol
1516
from urllib.parse import urlencode, urljoin
@@ -359,22 +360,35 @@ async def _exchange_code_for_token(self, auth_code: str, client_info: OAuthClien
359360
auth_base_url = self._get_authorization_base_url(self.server_url)
360361
token_url = urljoin(auth_base_url, "/token")
361362

363+
extra_headers = {}
364+
362365
token_data = {
363366
"grant_type": "authorization_code",
364367
"code": auth_code,
365368
"redirect_uri": str(self.client_metadata.redirect_uris[0]),
366-
"client_id": client_info.client_id,
367369
"code_verifier": self._code_verifier,
368370
}
369371

370-
if client_info.client_secret:
371-
token_data["client_secret"] = client_info.client_secret
372+
match client_info.token_endpoint_auth_method:
373+
case "none":
374+
token_data["client_id"] = client_info.client_id
375+
case "client_secret_post" if client_info.client_secret:
376+
token_data["client_id"] = client_info.client_id
377+
token_data["client_secret"] = client_info.client_secret
378+
case "client_secret_basic" if client_info.client_secret:
379+
basic = b64encode(
380+
f"{client_info.client_id}:{client_info.client_secret}".encode()
381+
).decode()
382+
extra_headers = {"Authorization": f"Basic {basic}"}
383+
case _:
384+
pass
372385

373386
async with httpx.AsyncClient() as client:
374387
response = await client.post(
375388
token_url,
376389
data=token_data,
377-
headers={"Content-Type": "application/x-www-form-urlencoded"},
390+
headers={"Content-Type": "application/x-www-form-urlencoded"}
391+
| extra_headers,
378392
timeout=30.0,
379393
)
380394

@@ -419,21 +433,35 @@ async def _refresh_access_token(self) -> bool:
419433
auth_base_url = self._get_authorization_base_url(self.server_url)
420434
token_url = urljoin(auth_base_url, "/token")
421435

436+
extra_headers = {}
437+
422438
refresh_data = {
423439
"grant_type": "refresh_token",
424440
"refresh_token": self._current_tokens.refresh_token,
425441
"client_id": client_info.client_id,
426442
}
427443

428-
if client_info.client_secret:
429-
refresh_data["client_secret"] = client_info.client_secret
444+
match client_info.token_endpoint_auth_method:
445+
case "none":
446+
refresh_data["client_id"] = client_info.client_id
447+
case "client_secret_post" if client_info.client_secret:
448+
refresh_data["client_id"] = client_info.client_id
449+
refresh_data["client_secret"] = client_info.client_secret
450+
case "client_secret_basic" if client_info.client_secret:
451+
basic = b64encode(
452+
f"{client_info.client_id}:{client_info.client_secret}".encode()
453+
).decode()
454+
extra_headers = {"Authorization": f"Basic {basic}"}
455+
case _:
456+
pass
430457

431458
try:
432459
async with httpx.AsyncClient() as client:
433460
response = await client.post(
434461
token_url,
435462
data=refresh_data,
436-
headers={"Content-Type": "application/x-www-form-urlencoded"},
463+
headers={"Content-Type": "application/x-www-form-urlencoded"}
464+
| extra_headers,
437465
timeout=30.0,
438466
)
439467

src/mcp/server/auth/handlers/token.py

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import base64
22
import hashlib
33
import time
4+
from base64 import b64decode
45
from dataclasses import dataclass
56
from typing import Annotated, Any, Literal
67

@@ -19,9 +20,6 @@ class AuthorizationCodeRequest(BaseModel):
1920
grant_type: Literal["authorization_code"]
2021
code: str = Field(..., description="The authorization code")
2122
redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
22-
client_id: str
23-
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
24-
client_secret: str | None = None
2523
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
2624
code_verifier: str = Field(..., description="PKCE code verifier")
2725

@@ -31,9 +29,50 @@ class RefreshTokenRequest(BaseModel):
3129
grant_type: Literal["refresh_token"]
3230
refresh_token: str = Field(..., description="The refresh token")
3331
scope: str | None = Field(None, description="Optional scope parameter")
32+
33+
34+
class NoneCredentials(BaseModel):
35+
client_id: str
36+
client_secret: None = None
37+
38+
39+
class PostCredentials(BaseModel):
3440
client_id: str
3541
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
36-
client_secret: str | None = None
42+
client_secret: str
43+
44+
45+
class FormCredentials(
46+
RootModel[
47+
Annotated[
48+
NoneCredentials | PostCredentials,
49+
Field(discriminator="client_secret"),
50+
]
51+
]
52+
):
53+
root: Annotated[
54+
NoneCredentials | PostCredentials,
55+
Field(discriminator="client_secret"),
56+
]
57+
58+
59+
class BasicCredentials(BaseModel):
60+
client_id: str
61+
client_secret: str
62+
63+
@classmethod
64+
def from_authorization(cls, authorization: str):
65+
try:
66+
if authorization.startswith("Basic "):
67+
[client_id, client_secret] = b64decode(authorization.removeprefix("Basic ")).decode().split(":", 1)
68+
return cls(client_id=client_id, client_secret=client_secret)
69+
except Exception:
70+
# TODO: better error here??
71+
return None
72+
return None
73+
74+
75+
Credentials = NoneCredentials | PostCredentials | BasicCredentials
3776

3877

3978
class TokenRequest(
@@ -90,19 +129,42 @@ async def handle(self, request: Request):
90129
try:
91130
form_data = await request.form()
92131
token_request = TokenRequest.model_validate(dict(form_data)).root
132+
try:
133+
credentials = FormCredentials.model_validate(dict(form_data)).root
134+
except ValidationError:
135+
credentials = (
136+
BasicCredentials.from_authorization(authorization)
137+
if (authorization := request.headers.get("Authorization"))
138+
else None
139+
)
140+
if not credentials:
141+
return self.response(
142+
TokenErrorResponse(
143+
error="invalid_request",
144+
error_description="missing credentials",
145+
)
146+
)
93147
except ValidationError as validation_error:
94148
return self.response(
95149
TokenErrorResponse(
96150
error="invalid_request",
97151
error_description=stringify_pydantic_error(validation_error),
98152
)
99153
)
100-
101154
try:
102155
client_info = await self.client_authenticator.authenticate(
103-
client_id=token_request.client_id,
104-
client_secret=token_request.client_secret,
156+
client_id=credentials.client_id,
157+
client_secret=credentials.client_secret,
105158
)
159+
match client_info.token_endpoint_auth_method:
160+
case "none" if not isinstance(credentials, NoneCredentials):
161+
raise AuthenticationError("Invalid credentials for client token_endpoint_auth_method")
162+
case "client_secret_post" if not isinstance(credentials, PostCredentials):
163+
raise AuthenticationError("Invalid credentials for client token_endpoint_auth_method")
164+
case "client_secret_basic" if not isinstance(credentials, BasicCredentials):
165+
raise AuthenticationError("Invalid credentials for client token_endpoint_auth_method")
166+
case _:
167+
pass
106168
except AuthenticationError as e:
107169
return self.response(
108170
TokenErrorResponse(
@@ -126,7 +188,7 @@ async def handle(self, request: Request):
126188
match token_request:
127189
case AuthorizationCodeRequest():
128190
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
129-
if auth_code is None or auth_code.client_id != token_request.client_id:
191+
if auth_code is None or auth_code.client_id != credentials.client_id:
130192
# if code belongs to different client, pretend it doesn't exist
131193
return self.response(
132194
TokenErrorResponse(
@@ -185,7 +247,7 @@ async def handle(self, request: Request):
185247

186248
case RefreshTokenRequest():
187249
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
188-
if refresh_token is None or refresh_token.client_id != token_request.client_id:
250+
if refresh_token is None or refresh_token.client_id != credentials.client_id:
189251
# if token belongs to different client, pretend it doesn't exist
190252
return self.response(
191253
TokenErrorResponse(

src/mcp/server/auth/routes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,10 @@ def build_metadata(
159159
response_types_supported=["code"],
160160
response_modes_supported=None,
161161
grant_types_supported=["authorization_code", "refresh_token"],
162-
token_endpoint_auth_methods_supported=["client_secret_post"],
162+
token_endpoint_auth_methods_supported=[
163+
"client_secret_post",
164+
"client_secret_basic",
165+
],
163166
token_endpoint_auth_signing_alg_values_supported=None,
164167
service_documentation=service_documentation_url,
165168
ui_locales_supported=None,
@@ -176,6 +179,9 @@ def build_metadata(
176179
# Add revocation endpoint if supported
177180
if revocation_options.enabled:
178181
metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
179-
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]
182+
metadata.revocation_endpoint_auth_methods_supported = [
183+
"client_secret_post",
184+
"client_secret_basic",
185+
]
180186

181187
return metadata

src/mcp/shared/auth.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@ class OAuthClientMetadata(BaseModel):
4242
"""
4343

4444
redirect_uris: list[AnyUrl] = Field(..., min_length=1)
45-
# token_endpoint_auth_method: this implementation only supports none &
46-
# client_secret_post;
47-
# ie: we do not support client_secret_basic
48-
token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
45+
token_endpoint_auth_method: Literal["none", "client_secret_post", "client_secret_basic"] = "client_secret_post"
4946
# grant_types: this implementation only supports authorization_code & refresh_token
5047
grant_types: list[Literal["authorization_code", "refresh_token"]] = [
5148
"authorization_code",

0 commit comments

Comments
 (0)