Skip to content

Commit 32a8f69

Browse files
committed
test: add utc datetime validation enforcement to all pydantic models
1 parent 6ae7ec4 commit 32a8f69

1 file changed

Lines changed: 58 additions & 1 deletion

File tree

diracx-core/tests/test_utc_datetime.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
from __future__ import annotations
44

5+
import importlib
6+
import inspect
7+
import pkgutil
58
from datetime import UTC, datetime, timedelta, timezone
69

710
import pytest
811
from pydantic import BaseModel, ValidationError
912

10-
from diracx.core.models._types import UTCDatetime
13+
import diracx.core.models
14+
from diracx.core.models._types import UTCDatetime, _validate_utc
1115

1216

1317
class SampleModel(BaseModel):
@@ -59,3 +63,56 @@ def test_iso_string_non_utc(self):
5963
def test_naive_iso_string(self):
6064
with pytest.raises(ValidationError):
6165
SampleModel(ts="2024-01-01T12:00:00")
66+
67+
68+
def _is_datetime_type(annotation: type) -> bool:
69+
"""Check if an annotation is datetime or a subclass of datetime."""
70+
try:
71+
return isinstance(annotation, type) and issubclass(annotation, datetime)
72+
except TypeError:
73+
return False
74+
75+
76+
def _collect_model_classes() -> list[type[BaseModel]]:
77+
"""Discover all BaseModel subclasses in diracx.core.models."""
78+
models = []
79+
package = diracx.core.models
80+
for _importer, modname, _ispkg in pkgutil.walk_packages(
81+
package.__path__, prefix=package.__name__ + "."
82+
):
83+
if modname.endswith("._types"):
84+
continue
85+
module = importlib.import_module(modname)
86+
for _name, obj in inspect.getmembers(module, inspect.isclass):
87+
if (
88+
issubclass(obj, BaseModel)
89+
and obj is not BaseModel
90+
and obj.__module__ == modname
91+
):
92+
models.append(obj)
93+
return models
94+
95+
96+
def _check_field_uses_utc_validator(model: type[BaseModel], field_name: str) -> bool:
97+
"""Check that a datetime field has the _validate_utc AfterValidator."""
98+
field_info = model.model_fields[field_name]
99+
return any(getattr(m, "func", None) is _validate_utc for m in field_info.metadata)
100+
101+
102+
def test_all_datetime_fields_use_utc_datetime():
103+
"""Ensure no pydantic model in diracx.core.models uses bare datetime.
104+
105+
Every datetime field must use UTCDatetime to enforce UTC validation.
106+
"""
107+
violations = []
108+
for model in _collect_model_classes():
109+
for field_name, field_info in model.model_fields.items():
110+
if not _is_datetime_type(field_info.annotation):
111+
continue
112+
if not _check_field_uses_utc_validator(model, field_name):
113+
violations.append(f"{model.__name__}.{field_name}")
114+
115+
assert not violations, (
116+
"The following fields use bare datetime instead of UTCDatetime:\n"
117+
+ "\n".join(f" - {v}" for v in violations)
118+
)

0 commit comments

Comments
 (0)