Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions scripts/regenerate_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ def _dump_defaults(model: type[BaseModel]) -> dict:
if info.default is not PydanticUndefined:
default = info.default
elif info.default_factory is not None:
# If the factory is itself a BaseModel subclass (e.g.
# default_factory=HTTPClientConfig), recurse into it instead of
# calling it — calling would run validators, defeating the point
# of this function. Factories that dynamically pick a concrete
# subclass (e.g. TransportConfig.create_default → ZMQTransportConfig)
# aren't types, so they fall through and get called as before.
if isinstance(info.default_factory, type) and issubclass(
info.default_factory, BaseModel
):
out[name] = _dump_defaults(info.default_factory)
continue
default = info.default_factory()
Comment thread
viraatc marked this conversation as resolved.
else:
# Required field — recurse if BaseModel, else None
Expand Down
20 changes: 16 additions & 4 deletions src/inference_endpoint/endpoint_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@

from .accumulator_protocol import SSEAccumulatorProtocol
from .adapter_protocol import HttpRequestAdapter
from .cpu_affinity import AffinityPlan, get_cpus_in_numa_node, get_current_numa_node
from .cpu_affinity import (
AffinityPlan,
UnsupportedPlatformError,
get_cpus_in_numa_node,
get_current_numa_node,
)
from .utils import get_ephemeral_port_limit, get_ephemeral_port_range

ADAPTER_MAP = {
Expand Down Expand Up @@ -262,17 +267,24 @@ def _get_auto_num_workers() -> int:
Users can override with explicit num_workers to use more cores (workers
will be pinned to additional cores outside NUMA domain if needed).

On non-Linux platforms (NUMA probing is Linux-only) falls back to
``min_workers`` so the config can still be constructed for local
development, template regeneration, and tests.

Returns:
Number of workers to use when num_workers is -1 (auto).
"""
min_workers = 10
max_workers = 24

numa_node = get_current_numa_node()
if numa_node is None:
try:
numa_node = get_current_numa_node()
if numa_node is None:
return min_workers
numa_cpus = get_cpus_in_numa_node(numa_node)
except UnsupportedPlatformError:
return min_workers
Comment thread
viraatc marked this conversation as resolved.

numa_cpus = get_cpus_in_numa_node(numa_node)
if not numa_cpus:
return min_workers

Expand Down
75 changes: 75 additions & 0 deletions tests/unit/config/test_regenerate_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for scripts/regenerate_templates.py.

`_dump_defaults` must extract defaults without constructing nested
BaseModels that appear as default_factory, because construction runs
validators (which may have platform-dependent side effects).
"""

from __future__ import annotations

import importlib.util
import sys
from pathlib import Path

from pydantic import BaseModel, Field, model_validator

_REPO_ROOT = Path(__file__).resolve().parents[3]
_SCRIPT = _REPO_ROOT / "scripts" / "regenerate_templates.py"


def _load_regenerate_templates():
"""Load scripts/regenerate_templates.py as a module (it is not a package)."""
if "regenerate_templates" in sys.modules:
return sys.modules["regenerate_templates"]
spec = importlib.util.spec_from_file_location("regenerate_templates", _SCRIPT)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules["regenerate_templates"] = module
spec.loader.exec_module(module)
return module


class TestDumpDefaultsSkipsBaseModelFactory:
def test_basemodel_factory_does_not_run_validator(self):
"""default_factory=<BaseModel subclass> must not invoke the model's validators."""
rt = _load_regenerate_templates()

call_count = 0

class Inner(BaseModel):
x: int = 42

@model_validator(mode="after")
def _count(self):
nonlocal call_count
call_count += 1
return self

class Outer(BaseModel):
inner: Inner = Field(default_factory=Inner)

# Sanity: constructing Inner() directly does invoke the validator.
Inner()
assert call_count == 1

call_count = 0
result = rt._dump_defaults(Outer)

assert call_count == 0, (
"Inner validator was invoked — _dump_defaults called the factory "
"instead of recursing."
)
assert result == {"inner": {"x": 42}}

def test_callable_factory_is_still_invoked(self):
"""Factories that are callables (not BaseModel subclasses) must still be called."""
rt = _load_regenerate_templates()

class Config(BaseModel):
tags: list[str] = Field(default_factory=lambda: ["default-tag"])

result = rt._dump_defaults(Config)
assert result == {"tags": ["default-tag"]}
45 changes: 45 additions & 0 deletions tests/unit/endpoint_client/test_http_client_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for HTTPClientConfig construction on non-Linux platforms.

NUMA probing is Linux-only; auto-detecting num_workers must fall back
gracefully so HTTPClientConfig() can be constructed anywhere.
"""

from unittest.mock import patch

from inference_endpoint.endpoint_client import config as cfg
from inference_endpoint.endpoint_client.cpu_affinity import UnsupportedPlatformError


class TestAutoNumWorkersNonLinux:
def _clear_cache(self):
cfg._get_auto_num_workers.cache_clear()

def test_get_current_numa_node_unsupported_falls_back_to_min(self):
self._clear_cache()
with patch.object(
cfg, "get_current_numa_node", side_effect=UnsupportedPlatformError("darwin")
):
assert cfg._get_auto_num_workers() == 10

def test_get_cpus_in_numa_node_unsupported_falls_back_to_min(self):
self._clear_cache()
with (
patch.object(cfg, "get_current_numa_node", return_value=0),
patch.object(
cfg,
"get_cpus_in_numa_node",
side_effect=UnsupportedPlatformError("darwin"),
),
):
assert cfg._get_auto_num_workers() == 10

def test_http_client_config_constructs_when_numa_unsupported(self):
self._clear_cache()
with patch.object(
cfg, "get_current_numa_node", side_effect=UnsupportedPlatformError("darwin")
):
c = cfg.HTTPClientConfig()
assert c.num_workers == 10
Loading