Skip to content

Commit 873bbe0

Browse files
committed
Implement client credentials auth flow
1 parent 679b229 commit 873bbe0

File tree

2 files changed

+62
-27
lines changed

2 files changed

+62
-27
lines changed

src/mcp/client/auth.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
from collections.abc import AsyncGenerator, Awaitable, Callable
1414
from dataclasses import dataclass, field
15-
from typing import Protocol
15+
from typing import Optional, Protocol
1616
from urllib.parse import urlencode, urljoin, urlparse
1717

1818
import anyio
@@ -87,8 +87,8 @@ class OAuthContext:
8787
server_url: str
8888
client_metadata: OAuthClientMetadata
8989
storage: TokenStorage
90-
redirect_handler: Callable[[str], Awaitable[None]]
91-
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]]
90+
redirect_handler: Optional[Callable[[str], Awaitable[None]]]
91+
callback_handler: Optional[Callable[[], Awaitable[tuple[str, str | None]]]]
9292
timeout: float = 300.0
9393

9494
# Discovered metadata
@@ -164,8 +164,8 @@ def __init__(
164164
server_url: str,
165165
client_metadata: OAuthClientMetadata,
166166
storage: TokenStorage,
167-
redirect_handler: Callable[[str], Awaitable[None]],
168-
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]],
167+
redirect_handler: Optional[Callable[[str], Awaitable[None]]] = None,
168+
callback_handler: Optional[Callable[[], Awaitable[tuple[str, str | None]]]] = None,
169169
timeout: float = 300.0,
170170
):
171171
"""Initialize OAuth2 authentication."""
@@ -250,8 +250,27 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
250250
except ValidationError as e:
251251
raise OAuthRegistrationError(f"Invalid registration response: {e}")
252252

253-
async def _perform_authorization(self) -> tuple[str, str]:
253+
async def _perform_authorization(self) -> httpx.Request:
254+
"""Perform the authorization flow."""
255+
if not self.context.client_info:
256+
raise OAuthFlowError("No client info available for authorization")
257+
258+
if "client_credentials" in self.context.client_info.grant_types:
259+
token_request = await self._exchange_token_client_credentials()
260+
return token_request
261+
pass
262+
else:
263+
auth_code, code_verifier = await self._perform_authorization_code_grant()
264+
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
265+
return token_request
266+
267+
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
254268
"""Perform the authorization redirect and get auth code."""
269+
if not self.context.redirect_handler:
270+
raise OAuthFlowError("No redirect handler provided for authorization code grant")
271+
if not self.context.callback_handler:
272+
raise OAuthFlowError("No callback handler provided for authorization code grant")
273+
255274
if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
256275
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint)
257276
else:
@@ -293,8 +312,8 @@ async def _perform_authorization(self) -> tuple[str, str]:
293312
# Return auth code and code verifier for token exchange
294313
return auth_code, pkce_params.code_verifier
295314

296-
async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request:
297-
"""Build token exchange request."""
315+
async def _exchange_token_authorization_code(self, auth_code: str, code_verifier: str) -> httpx.Request:
316+
"""Build token exchange request for authorization_code flow."""
298317
if not self.context.client_info:
299318
raise OAuthFlowError("Missing client info")
300319

@@ -320,6 +339,31 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
320339
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
321340
)
322341

342+
async def _exchange_token_client_credentials(self) -> httpx.Request:
343+
"""Build token exchange request for client_credentials flow."""
344+
if not self.context.client_info:
345+
raise OAuthFlowError("Missing client info")
346+
347+
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
348+
token_url = str(self.context.oauth_metadata.token_endpoint)
349+
else:
350+
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
351+
token_url = urljoin(auth_base_url, "/token")
352+
353+
token_data = {
354+
"grant_type": "client_credentials",
355+
"resource": self.context.get_resource_url(), # RFC 8707
356+
}
357+
358+
if self.context.client_info.client_id:
359+
token_data["client_id"] = self.context.client_info.client_id
360+
if self.context.client_info.client_secret:
361+
token_data["client_secret"] = self.context.client_info.client_secret
362+
363+
return httpx.Request(
364+
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
365+
)
366+
323367
async def _handle_token_response(self, response: httpx.Response) -> None:
324368
"""Handle token exchange response."""
325369
if response.status_code != 200:
@@ -429,12 +473,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
429473
registration_response = yield registration_request
430474
await self._handle_registration_response(registration_response)
431475

432-
# Step 4: Perform authorization
433-
auth_code, code_verifier = await self._perform_authorization()
434-
435-
# Step 5: Exchange authorization code for tokens
436-
token_request = await self._exchange_token(auth_code, code_verifier)
437-
token_response = yield token_request
476+
# Step 4: Perform authorization and complete token exchange
477+
token_response = yield await self._perform_authorization()
438478
await self._handle_token_response(token_response)
439479
except Exception as e:
440480
logger.error(f"OAuth flow error: {e}")
@@ -475,12 +515,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
475515
registration_response = yield registration_request
476516
await self._handle_registration_response(registration_response)
477517

478-
# Step 4: Perform authorization
479-
auth_code, code_verifier = await self._perform_authorization()
480-
481-
# Step 5: Exchange authorization code for tokens
482-
token_request = await self._exchange_token(auth_code, code_verifier)
483-
token_response = yield token_request
518+
# Step 4: Perform authorization and complete token exchange
519+
token_response = yield await self._perform_authorization()
484520
await self._handle_token_response(token_response)
485521

486522
# Retry with new tokens

src/mcp/shared/auth.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,12 @@ class OAuthClientMetadata(BaseModel):
4242
"""
4343

4444
redirect_uris: list[AnyUrl] = Field(..., min_length=1)
45-
# token_endpoint_auth_method: this implementation only supports none &
46-
# client_secret_post;
47-
# ie: we do not support client_secret_basic
48-
token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
49-
# grant_types: this implementation only supports authorization_code & refresh_token
50-
grant_types: list[Literal["authorization_code", "refresh_token"]] = [
45+
# supported auth methods for the token endpoint
46+
token_endpoint_auth_method: Literal["none", "client_secret_basic", "client_secret_post"] = "client_secret_post"
47+
# supported grant_types of this implementation
48+
grant_types: list[Literal["authorization_code", "client_credentials", "refresh_token"]] = [
5149
"authorization_code",
50+
"client_credentials",
5251
"refresh_token",
5352
]
5453
# this implementation only supports code; ie: it does not support implicit grants
@@ -96,7 +95,7 @@ class OAuthClientInformationFull(OAuthClientMetadata):
9695
(client information plus metadata).
9796
"""
9897

99-
client_id: str
98+
client_id: str | None = None
10099
client_secret: str | None = None
101100
client_id_issued_at: int | None = None
102101
client_secret_expires_at: int | None = None

0 commit comments

Comments
 (0)