|
| 1 | +"""Authlib-backed OAuth2 adapter for MCP HTTPX integration. |
| 2 | +
|
| 3 | +Provides :class:`AuthlibOAuthAdapter`, an ``httpx.Auth`` plugin that wraps |
| 4 | +``authlib.integrations.httpx_client.AsyncOAuth2Client`` to handle token |
| 5 | +acquisition, automatic refresh, and Bearer-header injection. |
| 6 | +
|
| 7 | +The adapter is a drop-in replacement for :class:`~mcp.client.auth.OAuthClientProvider` |
| 8 | +when you already have OAuth endpoints and credentials (i.e. no MCP-specific |
| 9 | +metadata discovery is needed). For full MCP discovery (PRM / OASM / DCR), |
| 10 | +continue to use :class:`~mcp.client.auth.OAuthClientProvider`. |
| 11 | +
|
| 12 | +Supported grant types in this release: |
| 13 | +- ``client_credentials`` — fully self-contained (no browser interaction) |
| 14 | +- ``authorization_code`` + PKCE — requires *redirect_handler* / *callback_handler* |
| 15 | +
|
| 16 | +Example (client_credentials):: |
| 17 | +
|
| 18 | + from mcp.client.auth import AuthlibAdapterConfig, AuthlibOAuthAdapter |
| 19 | +
|
| 20 | + config = AuthlibAdapterConfig( |
| 21 | + token_endpoint="https://auth.example.com/token", |
| 22 | + client_id="my-client", |
| 23 | + client_secret="secret", |
| 24 | + scopes=["read", "write"], |
| 25 | + ) |
| 26 | + adapter = AuthlibOAuthAdapter(config=config, storage=InMemoryTokenStorage()) |
| 27 | + async with httpx.AsyncClient(auth=adapter) as client: |
| 28 | + resp = await client.get("https://api.example.com/resource") |
| 29 | +""" |
| 30 | + |
| 31 | +from __future__ import annotations |
| 32 | + |
| 33 | +import logging |
| 34 | +import secrets |
| 35 | +import string |
| 36 | +from collections.abc import AsyncGenerator, Awaitable, Callable |
| 37 | +from typing import Any |
| 38 | + |
| 39 | +import anyio |
| 40 | +import httpx |
| 41 | +from authlib.integrations.httpx_client import AsyncOAuth2Client # type: ignore[import-untyped] |
| 42 | +from pydantic import BaseModel, Field |
| 43 | + |
| 44 | +from mcp.client.auth.exceptions import OAuthFlowError |
| 45 | +from mcp.client.auth.oauth2 import TokenStorage |
| 46 | +from mcp.shared.auth import OAuthToken |
| 47 | + |
| 48 | +logger = logging.getLogger(__name__) |
| 49 | + |
| 50 | +# --------------------------------------------------------------------------- |
| 51 | +# Configuration |
| 52 | +# --------------------------------------------------------------------------- |
| 53 | + |
| 54 | + |
| 55 | +class AuthlibAdapterConfig(BaseModel): |
| 56 | + """Configuration for :class:`AuthlibOAuthAdapter`. |
| 57 | +
|
| 58 | + Args: |
| 59 | + token_endpoint: URL of the OAuth 2.0 token endpoint (required). |
| 60 | + client_id: OAuth client identifier (required). |
| 61 | + client_secret: OAuth client secret; omit for public clients. |
| 62 | + scopes: List of OAuth scopes to request. |
| 63 | + token_endpoint_auth_method: How to authenticate at the token endpoint. |
| 64 | + Accepted values: ``"client_secret_basic"`` (default), |
| 65 | + ``"client_secret_post"``, ``"none"``. |
| 66 | + authorization_endpoint: URL of the authorization endpoint. When set, |
| 67 | + the adapter uses the *authorization_code + PKCE* grant on 401; when |
| 68 | + ``None`` (default) it uses *client_credentials*. |
| 69 | + redirect_uri: Redirect URI registered with the authorization server. |
| 70 | + Required when *authorization_endpoint* is set. |
| 71 | + leeway: Seconds before token expiry at which automatic refresh is |
| 72 | + triggered (default: 60). |
| 73 | + extra_token_params: Additional key-value pairs forwarded verbatim to |
| 74 | + every ``fetch_token`` call (e.g. ``{"audience": "..."}``). |
| 75 | + """ |
| 76 | + |
| 77 | + token_endpoint: str |
| 78 | + client_id: str |
| 79 | + client_secret: str | None = Field(default=None, repr=False) # excluded from repr to prevent secret leakage |
| 80 | + scopes: list[str] | None = None |
| 81 | + token_endpoint_auth_method: str = "client_secret_basic" |
| 82 | + # authorization_code flow (optional) |
| 83 | + authorization_endpoint: str | None = None |
| 84 | + redirect_uri: str | None = None |
| 85 | + # Authlib tuning |
| 86 | + leeway: int = 60 |
| 87 | + extra_token_params: dict[str, Any] | None = None |
| 88 | + |
| 89 | + |
| 90 | +# --------------------------------------------------------------------------- |
| 91 | +# Adapter |
| 92 | +# --------------------------------------------------------------------------- |
| 93 | + |
| 94 | + |
| 95 | +class AuthlibOAuthAdapter(httpx.Auth): |
| 96 | + """Authlib-backed ``httpx.Auth`` provider. |
| 97 | +
|
| 98 | + Wraps :class:`authlib.integrations.httpx_client.AsyncOAuth2Client` as a |
| 99 | + drop-in ``httpx.Auth`` plugin. Token storage is delegated to the same |
| 100 | + :class:`~mcp.client.auth.TokenStorage` protocol used by the existing |
| 101 | + :class:`~mcp.client.auth.OAuthClientProvider`. |
| 102 | +
|
| 103 | + Args: |
| 104 | + config: Adapter configuration (endpoints, credentials, scopes …). |
| 105 | + storage: Token persistence implementation. |
| 106 | + redirect_handler: Async callback that receives the authorization URL |
| 107 | + and opens it (browser, print, etc.). Required for |
| 108 | + *authorization_code* flow. |
| 109 | + callback_handler: Async callback that waits for the user to complete |
| 110 | + authorization and returns ``(code, state)``. Required for |
| 111 | + *authorization_code* flow. |
| 112 | + """ |
| 113 | + |
| 114 | + requires_response_body = True |
| 115 | + |
| 116 | + def __init__( |
| 117 | + self, |
| 118 | + config: AuthlibAdapterConfig, |
| 119 | + storage: TokenStorage, |
| 120 | + redirect_handler: Callable[[str], Awaitable[None]] | None = None, |
| 121 | + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, |
| 122 | + ) -> None: |
| 123 | + self.config = config |
| 124 | + self.storage = storage |
| 125 | + self.redirect_handler = redirect_handler |
| 126 | + self.callback_handler = callback_handler |
| 127 | + self._lock: anyio.Lock = anyio.Lock() |
| 128 | + self._initialized: bool = False |
| 129 | + |
| 130 | + scope_str = " ".join(config.scopes) if config.scopes else None |
| 131 | + self._client: AsyncOAuth2Client = AsyncOAuth2Client( |
| 132 | + client_id=config.client_id, |
| 133 | + client_secret=config.client_secret, |
| 134 | + scope=scope_str, |
| 135 | + redirect_uri=config.redirect_uri, |
| 136 | + token_endpoint_auth_method=config.token_endpoint_auth_method, |
| 137 | + update_token=self._on_token_update, |
| 138 | + leeway=config.leeway, |
| 139 | + ) |
| 140 | + |
| 141 | + # ------------------------------------------------------------------ |
| 142 | + # Internal helpers |
| 143 | + # ------------------------------------------------------------------ |
| 144 | + |
| 145 | + async def _on_token_update( |
| 146 | + self, |
| 147 | + token: dict[str, Any], |
| 148 | + refresh_token: str | None = None, # noqa: ARG002 (Authlib callback signature) |
| 149 | + access_token: str | None = None, # noqa: ARG002 |
| 150 | + ) -> None: |
| 151 | + """Authlib ``update_token`` callback — persists refreshed tokens.""" |
| 152 | + oauth_token = OAuthToken( |
| 153 | + access_token=token["access_token"], |
| 154 | + token_type=token.get("token_type", "Bearer"), |
| 155 | + expires_in=token.get("expires_in"), |
| 156 | + scope=token.get("scope"), |
| 157 | + refresh_token=token.get("refresh_token"), |
| 158 | + ) |
| 159 | + await self.storage.set_tokens(oauth_token) |
| 160 | + |
| 161 | + async def _initialize(self) -> None: |
| 162 | + """Load persisted tokens into the Authlib client on first use.""" |
| 163 | + stored = await self.storage.get_tokens() |
| 164 | + if stored: |
| 165 | + token_dict: dict[str, Any] = { |
| 166 | + "access_token": stored.access_token, |
| 167 | + "token_type": stored.token_type, |
| 168 | + } |
| 169 | + if stored.refresh_token is not None: |
| 170 | + token_dict["refresh_token"] = stored.refresh_token |
| 171 | + if stored.scope is not None: |
| 172 | + token_dict["scope"] = stored.scope |
| 173 | + if stored.expires_in is not None: |
| 174 | + token_dict["expires_in"] = stored.expires_in |
| 175 | + self._client.token = token_dict |
| 176 | + self._initialized = True |
| 177 | + |
| 178 | + def _build_token_request_params(self) -> dict[str, Any]: |
| 179 | + """Merge base params with any extra params from config.""" |
| 180 | + params: dict[str, Any] = {} |
| 181 | + if self.config.extra_token_params: |
| 182 | + params.update(self.config.extra_token_params) |
| 183 | + return params |
| 184 | + |
| 185 | + async def _fetch_client_credentials_token(self) -> None: |
| 186 | + """Acquire a token via the *client_credentials* grant.""" |
| 187 | + params = self._build_token_request_params() |
| 188 | + await self._client.fetch_token( |
| 189 | + self.config.token_endpoint, |
| 190 | + grant_type="client_credentials", |
| 191 | + **params, |
| 192 | + ) |
| 193 | + if self._client.token: |
| 194 | + await self._on_token_update(dict(self._client.token)) |
| 195 | + |
| 196 | + async def _perform_authorization_code_flow(self) -> None: |
| 197 | + """Acquire a token via *authorization_code + PKCE* grant. |
| 198 | +
|
| 199 | + Raises: |
| 200 | + OAuthFlowError: If *redirect_handler*, *callback_handler*, |
| 201 | + *authorization_endpoint*, or *redirect_uri* are missing. |
| 202 | + """ |
| 203 | + if not self.config.authorization_endpoint: |
| 204 | + raise OAuthFlowError("authorization_endpoint is required for authorization_code flow") |
| 205 | + if not self.config.redirect_uri: |
| 206 | + raise OAuthFlowError("redirect_uri is required for authorization_code flow") |
| 207 | + if self.redirect_handler is None: |
| 208 | + raise OAuthFlowError("redirect_handler is required for authorization_code flow") |
| 209 | + if self.callback_handler is None: |
| 210 | + raise OAuthFlowError("callback_handler is required for authorization_code flow") |
| 211 | + |
| 212 | + # Generate PKCE state + build authorization URL via Authlib |
| 213 | + state = secrets.token_urlsafe(32) |
| 214 | + # Authlib generates code_verifier/code_challenge internally when |
| 215 | + # code_challenge_method is set on the client. |
| 216 | + self._client.code_challenge_method = "S256" |
| 217 | + # Generate a random code_verifier (Authlib will compute the challenge) |
| 218 | + code_verifier = "".join( |
| 219 | + secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128) |
| 220 | + ) |
| 221 | + |
| 222 | + auth_url, _ = self._client.create_authorization_url( |
| 223 | + self.config.authorization_endpoint, |
| 224 | + state=state, |
| 225 | + code_verifier=code_verifier, |
| 226 | + ) |
| 227 | + |
| 228 | + await self.redirect_handler(auth_url) |
| 229 | + auth_code, returned_state = await self.callback_handler() |
| 230 | + |
| 231 | + if returned_state is None or not secrets.compare_digest(returned_state, state): |
| 232 | + raise OAuthFlowError(f"State mismatch: {returned_state!r} != {state!r}") |
| 233 | + if not auth_code: |
| 234 | + raise OAuthFlowError("No authorization code received from callback") |
| 235 | + |
| 236 | + params = self._build_token_request_params() |
| 237 | + await self._client.fetch_token( |
| 238 | + self.config.token_endpoint, |
| 239 | + grant_type="authorization_code", |
| 240 | + code=auth_code, |
| 241 | + redirect_uri=self.config.redirect_uri, |
| 242 | + code_verifier=code_verifier, |
| 243 | + **params, |
| 244 | + ) |
| 245 | + if self._client.token: |
| 246 | + await self._on_token_update(dict(self._client.token)) |
| 247 | + |
| 248 | + def _inject_bearer(self, request: httpx.Request) -> None: |
| 249 | + """Add ``Authorization: Bearer <token>`` header if a token is held.""" |
| 250 | + token = self._client.token |
| 251 | + if token and token.get("access_token"): |
| 252 | + request.headers["Authorization"] = f"Bearer {token['access_token']}" |
| 253 | + |
| 254 | + # ------------------------------------------------------------------ |
| 255 | + # httpx.Auth entry point |
| 256 | + # ------------------------------------------------------------------ |
| 257 | + |
| 258 | + async def async_auth_flow( |
| 259 | + self, request: httpx.Request |
| 260 | + ) -> AsyncGenerator[httpx.Request, httpx.Response]: |
| 261 | + """HTTPX auth flow: ensure a valid token then inject it into the request. |
| 262 | +
|
| 263 | + On a ``401`` response the adapter acquires a fresh token (via |
| 264 | + *client_credentials* or *authorization_code*) and retries once. |
| 265 | + """ |
| 266 | + async with self._lock: |
| 267 | + if not self._initialized: |
| 268 | + await self._initialize() |
| 269 | + |
| 270 | + # Let Authlib auto-refresh if the token is close to expiry |
| 271 | + if self._client.token: |
| 272 | + await self._client.ensure_active_token(self._client.token) |
| 273 | + |
| 274 | + self._inject_bearer(request) |
| 275 | + |
| 276 | + response = yield request |
| 277 | + |
| 278 | + if response.status_code == 401: |
| 279 | + async with self._lock: |
| 280 | + # Acquire a brand-new token |
| 281 | + if self.config.authorization_endpoint: |
| 282 | + await self._perform_authorization_code_flow() |
| 283 | + else: |
| 284 | + await self._fetch_client_credentials_token() |
| 285 | + self._inject_bearer(request) |
| 286 | + |
| 287 | + yield request |
0 commit comments