|
2 | 2 | import hashlib |
3 | 3 | import time |
4 | 4 | from base64 import b64decode |
| 5 | +from binascii import Error as Base64Error |
5 | 6 | from dataclasses import dataclass |
6 | 7 | from typing import Annotated, Any, Literal |
7 | 8 |
|
8 | 9 | from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError |
| 10 | +from starlette.datastructures import Headers |
9 | 11 | from starlette.requests import Request |
10 | 12 |
|
11 | 13 | from mcp.server.auth.errors import stringify_pydantic_error |
@@ -54,15 +56,17 @@ class BasicCredentials(BaseModel): |
54 | 56 | client_secret: str |
55 | 57 |
|
56 | 58 | @classmethod |
57 | | - def from_authorization(cls, authorization: str): |
| 59 | + def from_headers(cls, headers: Headers): |
| 60 | + if not (authorization := headers.get("Authorization")): |
| 61 | + raise AuthenticationError("Missing authorization header") |
58 | 62 | try: |
59 | | - if authorization.startswith("Basic "): |
60 | | - [client_id, client_secret] = b64decode(authorization.removeprefix("Basic ")).decode().split(":", 1) |
61 | | - return cls(client_id=client_id, client_secret=client_secret) |
62 | | - except Exception: |
63 | | - # TODO: better error here?? |
64 | | - return None |
65 | | - return None |
| 63 | + scheme, credentials = authorization.split(None, 1) |
| 64 | + if scheme.lower() != "basic": |
| 65 | + raise AuthenticationError("Expected Basic authentication scheme") |
| 66 | + client_id, client_secret = b64decode(credentials).decode().split(":", 1) |
| 67 | + return cls(client_id=client_id, client_secret=client_secret) |
| 68 | + except ValueError | Base64Error | UnicodeDecodeError: |
| 69 | + raise AuthenticationError("Invalid Basic authentication credentials") from None |
66 | 70 |
|
67 | 71 |
|
68 | 72 | Credentials = NoneCredentials | PostCredentials | BasicCredentials |
@@ -120,30 +124,21 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse): |
120 | 124 |
|
121 | 125 | async def handle(self, request: Request): |
122 | 126 | try: |
123 | | - form_data = await request.form() |
124 | | - token_request = TokenRequest.model_validate(dict(form_data)).root |
| 127 | + form_data = dict(await request.form()) |
| 128 | + token_request = TokenRequest.model_validate(form_data).root |
125 | 129 | try: |
126 | | - credentials = FormCredentials.model_validate(dict(form_data)).root |
| 130 | + credentials = FormCredentials.model_validate(form_data).root |
127 | 131 | except ValidationError: |
128 | | - credentials = ( |
129 | | - BasicCredentials.from_authorization(authorization) |
130 | | - if (authorization := request.headers.get("Authorization")) |
131 | | - else None |
132 | | - ) |
133 | | - if not credentials: |
134 | | - return self.response( |
135 | | - TokenErrorResponse( |
136 | | - error="invalid_request", |
137 | | - error_description="missing credentials", |
138 | | - ) |
139 | | - ) |
| 132 | + credentials = BasicCredentials.from_headers(request.headers) |
140 | 133 | except ValidationError as validation_error: |
141 | 134 | return self.response( |
142 | 135 | TokenErrorResponse( |
143 | 136 | error="invalid_request", |
144 | 137 | error_description=stringify_pydantic_error(validation_error), |
145 | 138 | ) |
146 | 139 | ) |
| 140 | + except AuthenticationError as auth_error: |
| 141 | + return self.response(TokenErrorResponse(error="invalid_request", error_description=auth_error.message)) |
147 | 142 | try: |
148 | 143 | client_info = await self.client_authenticator.authenticate( |
149 | 144 | client_id=credentials.client_id, |
|
0 commit comments