Skip to content

Commit 214cace

Browse files
committed
tidy: improved error handling in client_secret_basic
1 parent b54801d commit 214cace

1 file changed

Lines changed: 18 additions & 23 deletions

File tree

src/mcp/server/auth/handlers/token.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import hashlib
33
import time
44
from base64 import b64decode
5+
from binascii import Error as Base64Error
56
from dataclasses import dataclass
67
from typing import Annotated, Any, Literal
78

89
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
10+
from starlette.datastructures import Headers
911
from starlette.requests import Request
1012

1113
from mcp.server.auth.errors import stringify_pydantic_error
@@ -54,15 +56,17 @@ class BasicCredentials(BaseModel):
5456
client_secret: str
5557

5658
@classmethod
57-
def from_authorization(cls, authorization: str):
59+
def from_headers(cls, headers: Headers):
60+
if not (authorization := headers.get("Authorization")):
61+
raise AuthenticationError("Missing authorization header")
5862
try:
59-
if authorization.startswith("Basic "):
60-
[client_id, client_secret] = b64decode(authorization.removeprefix("Basic ")).decode().split(":", 1)
61-
return cls(client_id=client_id, client_secret=client_secret)
62-
except Exception:
63-
# TODO: better error here??
64-
return None
65-
return None
63+
scheme, credentials = authorization.split(None, 1)
64+
if scheme.lower() != "basic":
65+
raise AuthenticationError("Expected Basic authentication scheme")
66+
client_id, client_secret = b64decode(credentials).decode().split(":", 1)
67+
return cls(client_id=client_id, client_secret=client_secret)
68+
except ValueError | Base64Error | UnicodeDecodeError:
69+
raise AuthenticationError("Invalid Basic authentication credentials") from None
6670

6771

6872
Credentials = NoneCredentials | PostCredentials | BasicCredentials
@@ -120,30 +124,21 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
120124

121125
async def handle(self, request: Request):
122126
try:
123-
form_data = await request.form()
124-
token_request = TokenRequest.model_validate(dict(form_data)).root
127+
form_data = dict(await request.form())
128+
token_request = TokenRequest.model_validate(form_data).root
125129
try:
126-
credentials = FormCredentials.model_validate(dict(form_data)).root
130+
credentials = FormCredentials.model_validate(form_data).root
127131
except ValidationError:
128-
credentials = (
129-
BasicCredentials.from_authorization(authorization)
130-
if (authorization := request.headers.get("Authorization"))
131-
else None
132-
)
133-
if not credentials:
134-
return self.response(
135-
TokenErrorResponse(
136-
error="invalid_request",
137-
error_description="missing credentials",
138-
)
139-
)
132+
credentials = BasicCredentials.from_headers(request.headers)
140133
except ValidationError as validation_error:
141134
return self.response(
142135
TokenErrorResponse(
143136
error="invalid_request",
144137
error_description=stringify_pydantic_error(validation_error),
145138
)
146139
)
140+
except AuthenticationError as auth_error:
141+
return self.response(TokenErrorResponse(error="invalid_request", error_description=auth_error.message))
147142
try:
148143
client_info = await self.client_authenticator.authenticate(
149144
client_id=credentials.client_id,

0 commit comments

Comments
 (0)