|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import importlib |
| 6 | +import inspect |
| 7 | +import pkgutil |
5 | 8 | from datetime import UTC, datetime, timedelta, timezone |
6 | 9 |
|
7 | 10 | import pytest |
8 | 11 | from pydantic import BaseModel, ValidationError |
9 | 12 |
|
10 | | -from diracx.core.models._types import UTCDatetime |
| 13 | +import diracx.core.models |
| 14 | +from diracx.core.models._types import UTCDatetime, _validate_utc |
11 | 15 |
|
12 | 16 |
|
13 | 17 | class SampleModel(BaseModel): |
@@ -59,3 +63,56 @@ def test_iso_string_non_utc(self): |
59 | 63 | def test_naive_iso_string(self): |
60 | 64 | with pytest.raises(ValidationError): |
61 | 65 | SampleModel(ts="2024-01-01T12:00:00") |
| 66 | + |
| 67 | + |
| 68 | +def _is_datetime_type(annotation: type) -> bool: |
| 69 | + """Check if an annotation is datetime or a subclass of datetime.""" |
| 70 | + try: |
| 71 | + return isinstance(annotation, type) and issubclass(annotation, datetime) |
| 72 | + except TypeError: |
| 73 | + return False |
| 74 | + |
| 75 | + |
| 76 | +def _collect_model_classes() -> list[type[BaseModel]]: |
| 77 | + """Discover all BaseModel subclasses in diracx.core.models.""" |
| 78 | + models = [] |
| 79 | + package = diracx.core.models |
| 80 | + for _importer, modname, _ispkg in pkgutil.walk_packages( |
| 81 | + package.__path__, prefix=package.__name__ + "." |
| 82 | + ): |
| 83 | + if modname.endswith("._types"): |
| 84 | + continue |
| 85 | + module = importlib.import_module(modname) |
| 86 | + for _name, obj in inspect.getmembers(module, inspect.isclass): |
| 87 | + if ( |
| 88 | + issubclass(obj, BaseModel) |
| 89 | + and obj is not BaseModel |
| 90 | + and obj.__module__ == modname |
| 91 | + ): |
| 92 | + models.append(obj) |
| 93 | + return models |
| 94 | + |
| 95 | + |
| 96 | +def _check_field_uses_utc_validator(model: type[BaseModel], field_name: str) -> bool: |
| 97 | + """Check that a datetime field has the _validate_utc AfterValidator.""" |
| 98 | + field_info = model.model_fields[field_name] |
| 99 | + return any(getattr(m, "func", None) is _validate_utc for m in field_info.metadata) |
| 100 | + |
| 101 | + |
| 102 | +def test_all_datetime_fields_use_utc_datetime(): |
| 103 | + """Ensure no pydantic model in diracx.core.models uses bare datetime. |
| 104 | +
|
| 105 | + Every datetime field must use UTCDatetime to enforce UTC validation. |
| 106 | + """ |
| 107 | + violations = [] |
| 108 | + for model in _collect_model_classes(): |
| 109 | + for field_name, field_info in model.model_fields.items(): |
| 110 | + if not _is_datetime_type(field_info.annotation): |
| 111 | + continue |
| 112 | + if not _check_field_uses_utc_validator(model, field_name): |
| 113 | + violations.append(f"{model.__name__}.{field_name}") |
| 114 | + |
| 115 | + assert not violations, ( |
| 116 | + "The following fields use bare datetime instead of UTCDatetime:\n" |
| 117 | + + "\n".join(f" - {v}" for v in violations) |
| 118 | + ) |
0 commit comments