|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Tests for scripts/regenerate_templates.py. |
| 5 | +
|
| 6 | +`_dump_defaults` must extract defaults without constructing nested |
| 7 | +BaseModels that appear as default_factory, because construction runs |
| 8 | +validators (which may have platform-dependent side effects). |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import importlib.util |
| 14 | +import sys |
| 15 | +from pathlib import Path |
| 16 | + |
| 17 | +from pydantic import BaseModel, Field, model_validator |
| 18 | + |
| 19 | +_REPO_ROOT = Path(__file__).resolve().parents[3] |
| 20 | +_SCRIPT = _REPO_ROOT / "scripts" / "regenerate_templates.py" |
| 21 | + |
| 22 | + |
| 23 | +def _load_regenerate_templates(): |
| 24 | + """Load scripts/regenerate_templates.py as a module (it is not a package).""" |
| 25 | + if "regenerate_templates" in sys.modules: |
| 26 | + return sys.modules["regenerate_templates"] |
| 27 | + spec = importlib.util.spec_from_file_location("regenerate_templates", _SCRIPT) |
| 28 | + assert spec and spec.loader |
| 29 | + module = importlib.util.module_from_spec(spec) |
| 30 | + sys.modules["regenerate_templates"] = module |
| 31 | + spec.loader.exec_module(module) |
| 32 | + return module |
| 33 | + |
| 34 | + |
| 35 | +class TestDumpDefaultsSkipsBaseModelFactory: |
| 36 | + def test_basemodel_factory_does_not_run_validator(self): |
| 37 | + """default_factory=<BaseModel subclass> must not invoke the model's validators.""" |
| 38 | + rt = _load_regenerate_templates() |
| 39 | + |
| 40 | + call_count = 0 |
| 41 | + |
| 42 | + class Inner(BaseModel): |
| 43 | + x: int = 42 |
| 44 | + |
| 45 | + @model_validator(mode="after") |
| 46 | + def _count(self): |
| 47 | + nonlocal call_count |
| 48 | + call_count += 1 |
| 49 | + return self |
| 50 | + |
| 51 | + class Outer(BaseModel): |
| 52 | + inner: Inner = Field(default_factory=Inner) |
| 53 | + |
| 54 | + # Sanity: constructing Inner() directly does invoke the validator. |
| 55 | + Inner() |
| 56 | + assert call_count == 1 |
| 57 | + |
| 58 | + call_count = 0 |
| 59 | + result = rt._dump_defaults(Outer) |
| 60 | + |
| 61 | + assert call_count == 0, ( |
| 62 | + "Inner validator was invoked — _dump_defaults called the factory " |
| 63 | + "instead of recursing." |
| 64 | + ) |
| 65 | + assert result == {"inner": {"x": 42}} |
| 66 | + |
| 67 | + def test_callable_factory_is_still_invoked(self): |
| 68 | + """Factories that are callables (not BaseModel subclasses) must still be called.""" |
| 69 | + rt = _load_regenerate_templates() |
| 70 | + |
| 71 | + class Config(BaseModel): |
| 72 | + tags: list[str] = Field(default_factory=lambda: ["default-tag"]) |
| 73 | + |
| 74 | + result = rt._dump_defaults(Config) |
| 75 | + assert result == {"tags": ["default-tag"]} |
0 commit comments