Skip to content

Commit ce3847e

Browse files
committed
refactor: convert InsertedJob and TokenPayload from TypedDict to BaseModel
TypedDict fields are not validated by pydantic at runtime, so UTCDatetime annotations had no effect. Converting to BaseModel ensures datetime fields are actually validated on construction.
1 parent 32a8f69 commit ce3847e

7 files changed

Lines changed: 39 additions & 39 deletions

File tree

diracx-core/src/diracx/core/models/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import BaseModel
66
from typing_extensions import TypedDict
77

8-
from ._types import UTCDatetime
8+
from .types import UTCDatetime
99

1010

1111
class UserInfo(BaseModel):
@@ -56,7 +56,7 @@ class OpenIDConfiguration(TypedDict):
5656
code_challenge_methods_supported: list[str]
5757

5858

59-
class TokenPayload(TypedDict):
59+
class TokenPayload(BaseModel):
6060
jti: str
6161
exp: UTCDatetime
6262
dirac_policies: dict

diracx-core/src/diracx/core/models/job.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
from typing import Literal
1010

1111
from pydantic import BaseModel, Field, field_validator
12-
from typing_extensions import TypedDict
1312

14-
from ._types import UTCDatetime
13+
from .types import UTCDatetime
1514

1615

17-
class InsertedJob(TypedDict):
16+
class InsertedJob(BaseModel):
1817
JobID: int
1918
Status: str
2019
MinorStatus: str
File renamed without changes.

diracx-core/tests/test_utc_datetime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic import BaseModel, ValidationError
1212

1313
import diracx.core.models
14-
from diracx.core.models._types import UTCDatetime, _validate_utc
14+
from diracx.core.models.types import UTCDatetime, _validate_utc
1515

1616

1717
class SampleModel(BaseModel):
@@ -80,7 +80,7 @@ def _collect_model_classes() -> list[type[BaseModel]]:
8080
for _importer, modname, _ispkg in pkgutil.walk_packages(
8181
package.__path__, prefix=package.__name__ + "."
8282
):
83-
if modname.endswith("._types"):
83+
if modname.endswith(".types"):
8484
continue
8585
module = importlib.import_module(modname)
8686
for _name, obj in inspect.getmembers(module, inspect.isclass):

diracx-logic/src/diracx/logic/auth/token.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,14 @@ async def exchange_token(
332332
refresh_exp = uuid7_to_datetime(refresh_jti) + timedelta(
333333
minutes=refresh_token_expire_minutes
334334
)
335-
refresh_payload = {
336-
"jti": str(refresh_jti),
337-
"exp": refresh_exp,
335+
refresh_payload = RefreshTokenPayload(
336+
jti=str(refresh_jti),
337+
exp=refresh_exp,
338338
# legacy_exchange is used to indicate that the original refresh token
339339
# was obtained from the legacy_exchange endpoint
340-
"legacy_exchange": legacy_exchange,
341-
"dirac_policies": {},
342-
}
340+
legacy_exchange=legacy_exchange,
341+
dirac_policies={},
342+
)
343343

344344
# Generate access token payload
345345
# For now, the access token is only used to access DIRAC services,
@@ -348,22 +348,22 @@ async def exchange_token(
348348
access_exp = uuid7_to_datetime(access_jti) + timedelta(
349349
minutes=settings.access_token_expire_minutes
350350
)
351-
access_payload: AccessTokenPayload = {
352-
"sub": sub,
353-
"vo": vo,
354-
"iss": settings.token_issuer,
355-
"dirac_properties": list(properties),
356-
"jti": str(access_jti),
357-
"preferred_username": preferred_username,
358-
"dirac_group": dirac_group,
359-
"exp": access_exp,
360-
"dirac_policies": {},
361-
}
351+
access_payload = AccessTokenPayload(
352+
sub=sub,
353+
vo=vo,
354+
iss=settings.token_issuer,
355+
dirac_properties=list(properties),
356+
jti=str(access_jti),
357+
preferred_username=preferred_username,
358+
dirac_group=dirac_group,
359+
exp=access_exp,
360+
dirac_policies={},
361+
)
362362

363363
return access_payload, refresh_payload
364364

365365

366-
def create_token(payload: TokenPayload, settings: AuthSettings) -> str:
366+
def create_token(payload: TokenPayload | dict, settings: AuthSettings) -> str:
367367
"""Create a JWT token with the given payload and settings."""
368368
signing_key = None
369369
for key in settings.token_keystore.jwks.keys:
@@ -377,9 +377,10 @@ def create_token(payload: TokenPayload, settings: AuthSettings) -> str:
377377
if not signing_key:
378378
raise ValueError("No signing key found in JWKS")
379379

380+
claims = payload.model_dump() if isinstance(payload, TokenPayload) else payload
380381
return jwt.encode(
381382
header={"alg": signing_key.get("alg"), "kid": signing_key.get("kid")},
382-
claims=cast(Claims, payload),
383+
claims=cast(Claims, claims),
383384
key=settings.token_keystore.jwks,
384385
algorithms=settings.token_allowed_algorithms,
385386
)

diracx-routers/src/diracx/routers/auth/token.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ async def mint_token(
6161
dirac_refresh_policies[policy_name] = refresh_extra
6262

6363
# Create the access token
64-
access_payload["dirac_policies"] = dirac_access_policies
64+
access_payload.dirac_policies = dirac_access_policies
6565
access_token = create_token(access_payload, settings)
6666

6767
# Create the refresh token
6868
if refresh_payload:
69-
refresh_payload["dirac_policies"] = dirac_refresh_policies
69+
refresh_payload.dirac_policies = dirac_refresh_policies
7070
refresh_token = create_token(refresh_payload, settings)
7171
elif existing_refresh_token:
7272
refresh_token = existing_refresh_token

tests/make_token_local.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,17 @@ def main(token_keystore: str):
3535
jti = uuid7()
3636
expires_at = uuid7_to_datetime(jti) + timedelta(seconds=expires_in)
3737

38-
access_payload: AccessTokenPayload = {
39-
"sub": f"{vo}:{sub}",
40-
"vo": vo,
41-
"iss": settings.token_issuer,
42-
"dirac_properties": dirac_properties,
43-
"jti": str(jti),
44-
"preferred_username": preferred_username,
45-
"dirac_group": dirac_group,
46-
"exp": expires_at,
47-
"dirac_policies": {},
48-
}
38+
access_payload = AccessTokenPayload(
39+
sub=f"{vo}:{sub}",
40+
vo=vo,
41+
iss=settings.token_issuer,
42+
dirac_properties=dirac_properties,
43+
jti=str(jti),
44+
preferred_username=preferred_username,
45+
dirac_group=dirac_group,
46+
exp=expires_at,
47+
dirac_policies={},
48+
)
4949
token = TokenResponse(
5050
access_token=create_token(access_payload, settings),
5151
expires_in=expires_in,

0 commit comments

Comments
 (0)