@@ -204,6 +204,9 @@ def __init__(
204204 )
205205 self ._initialized = False
206206
207+ # State parameter for CSRF protection
208+ self .auth_state : str | None = None
209+
207210 def _extract_resource_metadata_from_www_auth (self , init_response : httpx .Response ) -> str | None :
208211 """
209212 Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
@@ -322,13 +325,13 @@ async def _perform_authorization(self) -> tuple[str, str]:
322325
323326 # Generate PKCE parameters
324327 pkce_params = PKCEParameters .generate ()
325- state = secrets .token_urlsafe (32 )
328+ self . auth_state = secrets .token_urlsafe (32 )
326329
327330 auth_params = {
328331 "response_type" : "code" ,
329332 "client_id" : self .context .client_info .client_id ,
330333 "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
331- "state" : state ,
334+ "state" : self . auth_state ,
332335 "code_challenge" : pkce_params .code_challenge ,
333336 "code_challenge_method" : "S256" ,
334337 }
@@ -346,8 +349,10 @@ async def _perform_authorization(self) -> tuple[str, str]:
346349 # Wait for callback
347350 auth_code , returned_state = await self .context .callback_handler ()
348351
349- if returned_state is None or not secrets .compare_digest (returned_state , state ):
350- raise OAuthFlowError (f"State parameter mismatch: { returned_state } != { state } " )
352+ if returned_state is None or not secrets .compare_digest (returned_state , self .auth_state ):
353+ raise OAuthFlowError (f"State parameter mismatch: { returned_state } != { self .auth_state } " )
354+
355+ self .auth_state = None
351356
352357 if not auth_code :
353358 raise OAuthFlowError ("No authorization code received" )
0 commit comments