Skip to content

Commit ed1704e

Browse files
viraatc and claude authored
fix: let HTTPClientConfig and template regen work on non-Linux (#291)
* fix: let HTTPClientConfig and template regen work on non-Linux

  Two independent issues that together broke pre-commit (and any HTTPClientConfig construction) on macOS:

  1. HTTPClientConfig()._resolve_defaults calls _get_auto_num_workers when num_workers=-1, which invokes get_current_numa_node(). That function is @require_linux and raises UnsupportedPlatformError on darwin. Fix: catch UnsupportedPlatformError and fall back to min_workers=10, matching the existing "NUMA not discoverable" branch.

  2. scripts/regenerate_templates.py::_dump_defaults was documented to avoid running model validators, but called default_factory() for every field — which constructs the nested model (and runs its validators) whenever the factory is a BaseModel subclass. Fix: when the factory is a BaseModel subclass, recurse into _dump_defaults(factory) instead. Factories that dynamically pick a concrete subclass (e.g. TransportConfig.create_default -> ZMQTransportConfig) are not types, so they still get called as before and the concrete subclass is walked via the existing isinstance(default, BaseModel) branch.

  Verified on Linux: pre-commit passes, templates are byte-identical. Verified against simulated non-Linux (UnsupportedPlatformError patched to always raise): HTTPClientConfig() constructs, and the regen script completes without entering the NUMA code path at all.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test: cover non-Linux num_workers fallback and _dump_defaults recursion

  Adds unit tests for both fixes in this PR:

  - tests/unit/endpoint_client/test_http_client_config.py: patches get_current_numa_node / get_cpus_in_numa_node to raise UnsupportedPlatformError and asserts _get_auto_num_workers and HTTPClientConfig() fall back to min_workers=10.
  - tests/unit/config/test_regenerate_templates.py: defines a BaseModel with a counter-incrementing model_validator used as default_factory, and asserts _dump_defaults does not invoke the validator while still emitting the nested defaults. Also covers the non-BaseModel callable factory path (lambda) to ensure it still gets called.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: remove accidentally-committed local files

  Personal editor/settings files (.claude.local.md, .nvimrc.lua) were pulled in by git add -A in the previous commit. They should stay untracked.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c980656 commit ed1704e

4 files changed

Lines changed: 147 additions & 4 deletions

File tree

scripts/regenerate_templates.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ def _dump_defaults(model: type[BaseModel]) -> dict:
149149
if info.default is not PydanticUndefined:
150150
default = info.default
151151
elif info.default_factory is not None:
152+
# If the factory is itself a BaseModel subclass (e.g.
153+
# default_factory=HTTPClientConfig), recurse into it instead of
154+
# calling it — calling would run validators, defeating the point
155+
# of this function. Factories that dynamically pick a concrete
156+
# subclass (e.g. TransportConfig.create_default → ZMQTransportConfig)
157+
# aren't types, so they fall through and get called as before.
158+
if isinstance(info.default_factory, type) and issubclass(
159+
info.default_factory, BaseModel
160+
):
161+
out[name] = _dump_defaults(info.default_factory)
162+
continue
152163
default = info.default_factory()
153164
else:
154165
# Required field — recurse if BaseModel, else None

src/inference_endpoint/endpoint_client/config.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737

3838
from .accumulator_protocol import SSEAccumulatorProtocol
3939
from .adapter_protocol import HttpRequestAdapter
40-
from .cpu_affinity import AffinityPlan, get_cpus_in_numa_node, get_current_numa_node
40+
from .cpu_affinity import (
41+
AffinityPlan,
42+
UnsupportedPlatformError,
43+
get_cpus_in_numa_node,
44+
get_current_numa_node,
45+
)
4146
from .utils import get_ephemeral_port_limit, get_ephemeral_port_range
4247

4348
ADAPTER_MAP = {
@@ -262,17 +267,24 @@ def _get_auto_num_workers() -> int:
262267
Users can override with explicit num_workers to use more cores (workers
263268
will be pinned to additional cores outside NUMA domain if needed).
264269
270+
On non-Linux platforms (NUMA probing is Linux-only) falls back to
271+
``min_workers`` so the config can still be constructed for local
272+
development, template regeneration, and tests.
273+
265274
Returns:
266275
Number of workers to use when num_workers is -1 (auto).
267276
"""
268277
min_workers = 10
269278
max_workers = 24
270279

271-
numa_node = get_current_numa_node()
272-
if numa_node is None:
280+
try:
281+
numa_node = get_current_numa_node()
282+
if numa_node is None:
283+
return min_workers
284+
numa_cpus = get_cpus_in_numa_node(numa_node)
285+
except UnsupportedPlatformError:
273286
return min_workers
274287

275-
numa_cpus = get_cpus_in_numa_node(numa_node)
276288
if not numa_cpus:
277289
return min_workers
278290

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for scripts/regenerate_templates.py.
5+
6+
`_dump_defaults` must extract defaults without constructing nested
7+
BaseModels that appear as default_factory, because construction runs
8+
validators (which may have platform-dependent side effects).
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import importlib.util
14+
import sys
15+
from pathlib import Path
16+
17+
from pydantic import BaseModel, Field, model_validator
18+
19+
_REPO_ROOT = Path(__file__).resolve().parents[3]
20+
_SCRIPT = _REPO_ROOT / "scripts" / "regenerate_templates.py"
21+
22+
23+
def _load_regenerate_templates():
    """Load scripts/regenerate_templates.py as a module (it is not a package).

    The loaded module is cached in ``sys.modules`` under the name
    ``regenerate_templates``, so repeated calls return the same object.

    Returns:
        The ``regenerate_templates`` module object.

    Raises:
        ImportError: If no import spec/loader can be built for the script.
    """
    module_name = "regenerate_templates"
    if module_name in sys.modules:
        return sys.modules[module_name]

    spec = importlib.util.spec_from_file_location(module_name, _SCRIPT)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {_SCRIPT}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the import machinery itself does.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module in the cache: a later call
        # would otherwise return it and hide the original failure.
        sys.modules.pop(module_name, None)
        raise
    return module
33+
34+
35+
class TestDumpDefaultsSkipsBaseModelFactory:
    """`_dump_defaults` must walk BaseModel factories instead of calling them."""

    def test_basemodel_factory_does_not_run_validator(self):
        """A BaseModel subclass used as default_factory is recursed into, not constructed."""
        module = _load_regenerate_templates()

        # Every validator run appends one entry; an empty list means "never ran".
        validator_calls = []

        class Inner(BaseModel):
            x: int = 42

            @model_validator(mode="after")
            def _count(self):
                validator_calls.append(1)
                return self

        class Outer(BaseModel):
            inner: Inner = Field(default_factory=Inner)

        # Sanity: constructing Inner() directly does invoke the validator.
        Inner()
        assert len(validator_calls) == 1

        validator_calls.clear()
        dumped = module._dump_defaults(Outer)

        assert not validator_calls, (
            "Inner validator was invoked — _dump_defaults called the factory "
            "instead of recursing."
        )
        assert dumped == {"inner": {"x": 42}}

    def test_callable_factory_is_still_invoked(self):
        """A plain callable factory (not a BaseModel subclass) is still called."""
        module = _load_regenerate_templates()

        class Config(BaseModel):
            tags: list[str] = Field(default_factory=lambda: ["default-tag"])

        assert module._dump_defaults(Config) == {"tags": ["default-tag"]}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for HTTPClientConfig construction on non-Linux platforms.
5+
6+
NUMA probing is Linux-only; auto-detecting num_workers must fall back
7+
gracefully so HTTPClientConfig() can be constructed anywhere.
8+
"""
9+
10+
from unittest.mock import patch
11+
12+
from inference_endpoint.endpoint_client import config as cfg
13+
from inference_endpoint.endpoint_client.cpu_affinity import UnsupportedPlatformError
14+
15+
16+
class TestAutoNumWorkersNonLinux:
    """Auto worker-count detection must fall back when NUMA probing is unsupported.

    ``cfg._get_auto_num_workers`` exposes ``cache_clear()``, i.e. its result is
    memoised. The cache is cleared both before and after every test: before,
    so the patched probes are actually consulted; after, so the fallback value
    computed under a patch cannot leak into tests that run later.
    """

    # Fallback worker count expected whenever NUMA probing is unavailable.
    _MIN_WORKERS = 10

    def _clear_cache(self):
        """Drop the memoised result of ``_get_auto_num_workers``."""
        cfg._get_auto_num_workers.cache_clear()

    def setup_method(self):
        """Start each test with an empty cache."""
        self._clear_cache()

    def teardown_method(self):
        """Discard the value cached while the probes were patched."""
        self._clear_cache()

    def test_get_current_numa_node_unsupported_falls_back_to_min(self):
        """UnsupportedPlatformError from get_current_numa_node yields the fallback."""
        with patch.object(
            cfg, "get_current_numa_node", side_effect=UnsupportedPlatformError("darwin")
        ):
            assert cfg._get_auto_num_workers() == self._MIN_WORKERS

    def test_get_cpus_in_numa_node_unsupported_falls_back_to_min(self):
        """UnsupportedPlatformError from get_cpus_in_numa_node yields the fallback."""
        with (
            patch.object(cfg, "get_current_numa_node", return_value=0),
            patch.object(
                cfg,
                "get_cpus_in_numa_node",
                side_effect=UnsupportedPlatformError("darwin"),
            ),
        ):
            assert cfg._get_auto_num_workers() == self._MIN_WORKERS

    def test_http_client_config_constructs_when_numa_unsupported(self):
        """HTTPClientConfig() can be constructed and picks the fallback count."""
        with patch.object(
            cfg, "get_current_numa_node", side_effect=UnsupportedPlatformError("darwin")
        ):
            c = cfg.HTTPClientConfig()
        assert c.num_workers == self._MIN_WORKERS

0 commit comments

Comments
 (0)