1212import time
1313from collections .abc import AsyncGenerator , Awaitable , Callable
1414from dataclasses import dataclass , field
15- from typing import Protocol
15+ from typing import Optional , Protocol
1616from urllib .parse import urlencode , urljoin , urlparse
1717
1818import anyio
@@ -87,8 +87,8 @@ class OAuthContext:
8787 server_url : str
8888 client_metadata : OAuthClientMetadata
8989 storage : TokenStorage
90- redirect_handler : Callable [[str ], Awaitable [None ]]
91- callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]]
90+ redirect_handler : Optional [ Callable [[str ], Awaitable [None ] ]]
91+ callback_handler : Optional [ Callable [[], Awaitable [tuple [str , str | None ] ]]]
9292 timeout : float = 300.0
9393
9494 # Discovered metadata
@@ -164,8 +164,8 @@ def __init__(
164164 server_url : str ,
165165 client_metadata : OAuthClientMetadata ,
166166 storage : TokenStorage ,
167- redirect_handler : Callable [[str ], Awaitable [None ]],
168- callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]],
167+ redirect_handler : Optional [ Callable [[str ], Awaitable [None ]]] = None ,
168+ callback_handler : Optional [ Callable [[], Awaitable [tuple [str , str | None ]]]] = None ,
169169 timeout : float = 300.0 ,
170170 ):
171171 """Initialize OAuth2 authentication."""
@@ -250,8 +250,27 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
250250 except ValidationError as e :
251251 raise OAuthRegistrationError (f"Invalid registration response: { e } " )
252252
253- async def _perform_authorization (self ) -> tuple [str , str ]:
253+ async def _perform_authorization (self ) -> httpx .Request :
254+ """Perform the authorization flow."""
255+ if not self .context .client_info :
256+ raise OAuthFlowError ("No client info available for authorization" )
257+
258+ if "client_credentials" in self .context .client_info .grant_types :
259+ token_request = await self ._exchange_token_client_credentials ()
260+ return token_request
261+ pass
262+ else :
263+ auth_code , code_verifier = await self ._perform_authorization_code_grant ()
264+ token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
265+ return token_request
266+
267+ async def _perform_authorization_code_grant (self ) -> tuple [str , str ]:
254268 """Perform the authorization redirect and get auth code."""
269+ if not self .context .redirect_handler :
270+ raise OAuthFlowError ("No redirect handler provided for authorization code grant" )
271+ if not self .context .callback_handler :
272+ raise OAuthFlowError ("No callback handler provided for authorization code grant" )
273+
255274 if self .context .oauth_metadata and self .context .oauth_metadata .authorization_endpoint :
256275 auth_endpoint = str (self .context .oauth_metadata .authorization_endpoint )
257276 else :
@@ -293,8 +312,8 @@ async def _perform_authorization(self) -> tuple[str, str]:
293312 # Return auth code and code verifier for token exchange
294313 return auth_code , pkce_params .code_verifier
295314
296- async def _exchange_token (self , auth_code : str , code_verifier : str ) -> httpx .Request :
297- """Build token exchange request."""
315+ async def _exchange_token_authorization_code (self , auth_code : str , code_verifier : str ) -> httpx .Request :
316+ """Build token exchange request for authorization_code flow ."""
298317 if not self .context .client_info :
299318 raise OAuthFlowError ("Missing client info" )
300319
@@ -320,6 +339,31 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
320339 "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
321340 )
322341
342+ async def _exchange_token_client_credentials (self ) -> httpx .Request :
343+ """Build token exchange request for client_credentials flow."""
344+ if not self .context .client_info :
345+ raise OAuthFlowError ("Missing client info" )
346+
347+ if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
348+ token_url = str (self .context .oauth_metadata .token_endpoint )
349+ else :
350+ auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
351+ token_url = urljoin (auth_base_url , "/token" )
352+
353+ token_data = {
354+ "grant_type" : "client_credentials" ,
355+ "resource" : self .context .get_resource_url (), # RFC 8707
356+ }
357+
358+ if self .context .client_info .client_id :
359+ token_data ["client_id" ] = self .context .client_info .client_id
360+ if self .context .client_info .client_secret :
361+ token_data ["client_secret" ] = self .context .client_info .client_secret
362+
363+ return httpx .Request (
364+ "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
365+ )
366+
323367 async def _handle_token_response (self , response : httpx .Response ) -> None :
324368 """Handle token exchange response."""
325369 if response .status_code != 200 :
@@ -429,12 +473,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
429473 registration_response = yield registration_request
430474 await self ._handle_registration_response (registration_response )
431475
432- # Step 4: Perform authorization
433- auth_code , code_verifier = await self ._perform_authorization ()
434-
435- # Step 5: Exchange authorization code for tokens
436- token_request = await self ._exchange_token (auth_code , code_verifier )
437- token_response = yield token_request
476+ # Step 4: Perform authorization and complete token exchange
477+ token_response = yield await self ._perform_authorization ()
438478 await self ._handle_token_response (token_response )
439479 except Exception as e :
440480 logger .error (f"OAuth flow error: { e } " )
@@ -475,12 +515,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
475515 registration_response = yield registration_request
476516 await self ._handle_registration_response (registration_response )
477517
478- # Step 4: Perform authorization
479- auth_code , code_verifier = await self ._perform_authorization ()
480-
481- # Step 5: Exchange authorization code for tokens
482- token_request = await self ._exchange_token (auth_code , code_verifier )
483- token_response = yield token_request
518+ # Step 4: Perform authorization and complete token exchange
519+ token_response = yield await self ._perform_authorization ()
484520 await self ._handle_token_response (token_response )
485521
486522 # Retry with new tokens
0 commit comments