@@ -30,6 +30,7 @@ def __init__(
3030 client_id : str ,
3131 client_secret : str ,
3232 * ,
33+ cli_client_id : Optional [str ] = None ,
3334 signature_cache_ttl : int = DEFAULT_SIGNATURE_CACHE_TTL ,
3435 openid_configuration : Optional [Dict [str , Any ]] = None ,
3536 ):
@@ -40,12 +41,14 @@ def __init__(
4041 base_url: The OIDC issuer URL (base URL of the authentication server)
4142 client_id: The OAuth2 client ID
4243 client_secret: The OAuth2 client secret
44+ cli_client_id: Optional CLI client ID (for validating tokens from CLI)
4345 signature_cache_ttl: Seconds to cache the OIDC discovery/JWKS responses
4446 openid_configuration: Optional pre-loaded OIDC configuration (used mainly for testing)
4547 """
4648 self .base_url = base_url .rstrip ("/" )
4749 self .client_id = client_id
4850 self .client_secret = client_secret
51+ self .cli_client_id = cli_client_id
4952 self ._discovery = discovery .configure (cache_ttl = signature_cache_ttl )
5053 self ._openid_configuration = openid_configuration
5154
@@ -64,20 +67,55 @@ async def _get_algorithms(self) -> List[str]:
6467 oidc_config = await self ._get_openid_configuration ()
6568 return await asyncio .to_thread (self ._discovery .signing_algos , oidc_config )
6669
70+ def _get_valid_audiences (self ) -> List [str ]:
71+ """Get list of valid audiences (client IDs) for token validation."""
72+ audiences = [self .client_id ]
73+ if self .cli_client_id :
74+ audiences .append (self .cli_client_id )
75+ return audiences
76+
6777 async def _decode_token (
6878 self , token : str , * , audience : Optional [str ] = None
6979 ) -> Dict [str , Any ]:
7080 oidc_config = await self ._get_openid_configuration ()
7181 jwks = await self ._get_jwks ()
7282 algorithms = await self ._get_algorithms ()
73- return jwt .decode (
74- token ,
75- jwks ,
76- algorithms = algorithms ,
77- audience = audience or self .client_id ,
78- issuer = oidc_config .get ("issuer" , self .base_url ),
79- options = {"verify_at_hash" : False },
80- )
83+ issuer = oidc_config .get ("issuer" , self .base_url )
84+
85+ if audience :
86+ # Single audience specified, use it directly
87+ return jwt .decode (
88+ token ,
89+ jwks ,
90+ algorithms = algorithms ,
91+ audience = audience ,
92+ issuer = issuer ,
93+ options = {"verify_at_hash" : False },
94+ )
95+
96+ # Try each valid audience until one succeeds
97+ valid_audiences = self ._get_valid_audiences ()
98+ last_error = None
99+ for aud in valid_audiences :
100+ try :
101+ return jwt .decode (
102+ token ,
103+ jwks ,
104+ algorithms = algorithms ,
105+ audience = aud ,
106+ issuer = issuer ,
107+ options = {"verify_at_hash" : False },
108+ )
109+ except jwt .JWTClaimsError as e :
110+ # Audience mismatch, try next
111+ last_error = e
112+ continue
113+ except Exception :
114+ # Other errors (signature, expiration, etc.) should fail immediately
115+ raise
116+
117+ # None of the audiences matched
118+ raise last_error or jwt .JWTClaimsError ("Invalid audience" )
81119
82120 async def get_auth_url (
83121 self , redirect_uri : str , scope : List [str ], state : Optional [str ] = None
@@ -165,8 +203,16 @@ async def validate_access_token(self, token: str) -> bool:
165203 Raises:
166204 Exception if validation fails
167205 """
168- await self ._decode_token (token )
169- return True
206+ try :
207+ result = await self ._decode_token (token )
208+ print (f"Token validated: { result } " )
209+ return True
210+ except Exception as e :
211+ oidc_config = await self ._get_openid_configuration ()
212+ print (f"Validation failed: { e } " )
213+ print (f"Expected issuer: { oidc_config .get ('issuer' , self .base_url )} " )
214+ print (f"Expected audiences: { self ._get_valid_audiences ()} " )
215+ raise
170216
171217 async def get_user_info (self , access_token : str ) -> Dict [str , Any ]:
172218 """
0 commit comments