diff --git a/scripts/regenerate_templates.py b/scripts/regenerate_templates.py index eb84a6dd..30d0bb0d 100644 --- a/scripts/regenerate_templates.py +++ b/scripts/regenerate_templates.py @@ -149,6 +149,17 @@ def _dump_defaults(model: type[BaseModel]) -> dict: if info.default is not PydanticUndefined: default = info.default elif info.default_factory is not None: + # If the factory is itself a BaseModel subclass (e.g. + # default_factory=HTTPClientConfig), recurse into it instead of + # calling it — calling would run validators, defeating the point + # of this function. Factories that dynamically pick a concrete + # subclass (e.g. TransportConfig.create_default → ZMQTransportConfig) + # aren't types, so they fall through and get called as before. + if isinstance(info.default_factory, type) and issubclass( + info.default_factory, BaseModel + ): + out[name] = _dump_defaults(info.default_factory) + continue default = info.default_factory() else: # Required field — recurse if BaseModel, else None diff --git a/src/inference_endpoint/endpoint_client/config.py b/src/inference_endpoint/endpoint_client/config.py index c5509fc2..599ca7c4 100644 --- a/src/inference_endpoint/endpoint_client/config.py +++ b/src/inference_endpoint/endpoint_client/config.py @@ -37,7 +37,12 @@ from .accumulator_protocol import SSEAccumulatorProtocol from .adapter_protocol import HttpRequestAdapter -from .cpu_affinity import AffinityPlan, get_cpus_in_numa_node, get_current_numa_node +from .cpu_affinity import ( + AffinityPlan, + UnsupportedPlatformError, + get_cpus_in_numa_node, + get_current_numa_node, +) from .utils import get_ephemeral_port_limit, get_ephemeral_port_range ADAPTER_MAP = { @@ -262,17 +267,24 @@ def _get_auto_num_workers() -> int: Users can override with explicit num_workers to use more cores (workers will be pinned to additional cores outside NUMA domain if needed). + On non-Linux platforms (NUMA probing is Linux-only) falls back to + ``min_workers`` so the config can still be constructed for local + development, template regeneration, and tests. + Returns: Number of workers to use when num_workers is -1 (auto). """ min_workers = 10 max_workers = 24 - numa_node = get_current_numa_node() - if numa_node is None: + try: + numa_node = get_current_numa_node() + if numa_node is None: + return min_workers + numa_cpus = get_cpus_in_numa_node(numa_node) + except UnsupportedPlatformError: return min_workers - numa_cpus = get_cpus_in_numa_node(numa_node) if not numa_cpus: return min_workers diff --git a/tests/unit/config/test_regenerate_templates.py b/tests/unit/config/test_regenerate_templates.py new file mode 100644 index 00000000..c40ece81 --- /dev/null +++ b/tests/unit/config/test_regenerate_templates.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for scripts/regenerate_templates.py. + +`_dump_defaults` must extract defaults without constructing nested +BaseModels that appear as default_factory, because construction runs +validators (which may have platform-dependent side effects). +""" + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path + +from pydantic import BaseModel, Field, model_validator + +_REPO_ROOT = Path(__file__).resolve().parents[3] +_SCRIPT = _REPO_ROOT / "scripts" / "regenerate_templates.py" + + +def _load_regenerate_templates(): + """Load scripts/regenerate_templates.py as a module (it is not a package).""" + if "regenerate_templates" in sys.modules: + return sys.modules["regenerate_templates"] + spec = importlib.util.spec_from_file_location("regenerate_templates", _SCRIPT) + assert spec and spec.loader + module = importlib.util.module_from_spec(spec) + sys.modules["regenerate_templates"] = module + spec.loader.exec_module(module) + return module + + +class TestDumpDefaultsSkipsBaseModelFactory: + def test_basemodel_factory_does_not_run_validator(self): + """default_factory= must not invoke the model's validators.""" + rt = _load_regenerate_templates() + + call_count = 0 + + class Inner(BaseModel): + x: int = 42 + + @model_validator(mode="after") + def _count(self): + nonlocal call_count + call_count += 1 + return self + + class Outer(BaseModel): + inner: Inner = Field(default_factory=Inner) + + # Sanity: constructing Inner() directly does invoke the validator. + Inner() + assert call_count == 1 + + call_count = 0 + result = rt._dump_defaults(Outer) + + assert call_count == 0, ( + "Inner validator was invoked — _dump_defaults called the factory " + "instead of recursing." + ) + assert result == {"inner": {"x": 42}} + + def test_callable_factory_is_still_invoked(self): + """Factories that are callables (not BaseModel subclasses) must still be called.""" + rt = _load_regenerate_templates() + + class Config(BaseModel): + tags: list[str] = Field(default_factory=lambda: ["default-tag"]) + + result = rt._dump_defaults(Config) + assert result == {"tags": ["default-tag"]} diff --git a/tests/unit/endpoint_client/test_http_client_config.py b/tests/unit/endpoint_client/test_http_client_config.py new file mode 100644 index 00000000..22e251f3 --- /dev/null +++ b/tests/unit/endpoint_client/test_http_client_config.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for HTTPClientConfig construction on non-Linux platforms. + +NUMA probing is Linux-only; auto-detecting num_workers must fall back +gracefully so HTTPClientConfig() can be constructed anywhere. +""" + +from unittest.mock import patch + +from inference_endpoint.endpoint_client import config as cfg +from inference_endpoint.endpoint_client.cpu_affinity import UnsupportedPlatformError + + +class TestAutoNumWorkersNonLinux: + def _clear_cache(self): + cfg._get_auto_num_workers.cache_clear() + + def test_get_current_numa_node_unsupported_falls_back_to_min(self): + self._clear_cache() + with patch.object( + cfg, "get_current_numa_node", side_effect=UnsupportedPlatformError("darwin") + ): + assert cfg._get_auto_num_workers() == 10 + + def test_get_cpus_in_numa_node_unsupported_falls_back_to_min(self): + self._clear_cache() + with ( + patch.object(cfg, "get_current_numa_node", return_value=0), + patch.object( + cfg, + "get_cpus_in_numa_node", + side_effect=UnsupportedPlatformError("darwin"), + ), + ): + assert cfg._get_auto_num_workers() == 10 + + def test_http_client_config_constructs_when_numa_unsupported(self): + self._clear_cache() + with patch.object( + cfg, "get_current_numa_node", side_effect=UnsupportedPlatformError("darwin") + ): + c = cfg.HTTPClientConfig() + assert c.num_workers == 10