@@ -332,14 +332,14 @@ async def exchange_token(
332332 refresh_exp = uuid7_to_datetime (refresh_jti ) + timedelta (
333333 minutes = refresh_token_expire_minutes
334334 )
335- refresh_payload = {
336- " jti" : str (refresh_jti ),
337- " exp" : refresh_exp ,
335+ refresh_payload = RefreshTokenPayload (
336+ jti = str (refresh_jti ),
337+ exp = refresh_exp ,
338338 # legacy_exchange is used to indicate that the original refresh token
339339 # was obtained from the legacy_exchange endpoint
340- " legacy_exchange" : legacy_exchange ,
341- " dirac_policies" : {},
342- }
340+ legacy_exchange = legacy_exchange ,
341+ dirac_policies = {},
342+ )
343343
344344 # Generate access token payload
345345 # For now, the access token is only used to access DIRAC services,
@@ -348,22 +348,22 @@ async def exchange_token(
348348 access_exp = uuid7_to_datetime (access_jti ) + timedelta (
349349 minutes = settings .access_token_expire_minutes
350350 )
351- access_payload : AccessTokenPayload = {
352- " sub" : sub ,
353- "vo" : vo ,
354- " iss" : settings .token_issuer ,
355- " dirac_properties" : list (properties ),
356- " jti" : str (access_jti ),
357- " preferred_username" : preferred_username ,
358- " dirac_group" : dirac_group ,
359- " exp" : access_exp ,
360- " dirac_policies" : {},
361- }
351+ access_payload = AccessTokenPayload (
352+ sub = sub ,
353+ vo = vo ,
354+ iss = settings .token_issuer ,
355+ dirac_properties = list (properties ),
356+ jti = str (access_jti ),
357+ preferred_username = preferred_username ,
358+ dirac_group = dirac_group ,
359+ exp = access_exp ,
360+ dirac_policies = {},
361+ )
362362
363363 return access_payload , refresh_payload
364364
365365
366- def create_token (payload : TokenPayload , settings : AuthSettings ) -> str :
366+ def create_token (payload : TokenPayload | dict , settings : AuthSettings ) -> str :
367367 """Create a JWT token with the given payload and settings."""
368368 signing_key = None
369369 for key in settings .token_keystore .jwks .keys :
@@ -377,9 +377,10 @@ def create_token(payload: TokenPayload, settings: AuthSettings) -> str:
377377 if not signing_key :
378378 raise ValueError ("No signing key found in JWKS" )
379379
380+ claims = payload .model_dump () if isinstance (payload , TokenPayload ) else payload
380381 return jwt .encode (
381382 header = {"alg" : signing_key .get ("alg" ), "kid" : signing_key .get ("kid" )},
382- claims = cast (Claims , payload ),
383+ claims = cast (Claims , claims ),
383384 key = settings .token_keystore .jwks ,
384385 algorithms = settings .token_allowed_algorithms ,
385386 )
0 commit comments