Skip to content

Commit b94b88b

Browse files
authored
feat(cli): bootstrap default configs on CLI startup (#401)
* feat(cli): bootstrap default configs on command run * fix(cli): use active interpreter in bootstrap warning * refactor(cli): simplify bootstrap warning flow * refactor(cli): bootstrap defaults in main entrypoint * refactor(cli): keep bootstrap ownership in main * test(cli): cover lazy dispatch and runtime failure flag * refactor(cli): remove redundant bootstrap state * test(cli): assert bootstrap warning includes error * test: address cli bootstrap review feedback
1 parent eac63a1 commit b94b88b

7 files changed

Lines changed: 112 additions & 29 deletions

File tree

packages/data-designer/src/data_designer/cli/lazy_group.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,6 @@ def make_context(
6060
) -> click.Context:
6161
return self._resolve().make_context(info_name, args, parent, **extra)
6262

63-
def invoke(self, ctx: click.Context) -> Any:
64-
return self._resolve().invoke(ctx)
65-
6663

6764
def create_lazy_typer_group(
6865
lazy_subcommands: dict[str, dict[str, str]],

packages/data-designer/src/data_designer/cli/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import typer
77

88
from data_designer.cli.lazy_group import create_lazy_typer_group
9+
from data_designer.cli.runtime import ensure_cli_default_model_settings
910

1011
_CMD = "data_designer.cli.commands"
1112

@@ -105,6 +106,7 @@
105106

106107
def main() -> None:
107108
"""Main entry point for the CLI."""
109+
ensure_cli_default_model_settings()
108110
app()
109111

110112

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from data_designer.cli.ui import print_warning
7+
from data_designer.config.default_model_settings import resolve_seed_default_model_settings
8+
9+
10+
def ensure_cli_default_model_settings() -> None:
11+
"""Best-effort bootstrap for CLI default model settings.
12+
13+
Repeated calls are safe because ``resolve_seed_default_model_settings()``
14+
only writes missing files/directories.
15+
"""
16+
try:
17+
resolve_seed_default_model_settings()
18+
except Exception as e:
19+
print_warning(
20+
"Could not initialize default model providers and model configs automatically. "
21+
f"The command will continue. Error: {e}. "
22+
"You will need to configure providers and models manually with "
23+
"`data-designer config providers` and `data-designer config models`."
24+
)

packages/data-designer/src/data_designer/cli/utils/config_loader.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from urllib.parse import urlparse
1010

1111
from data_designer.config.config_builder import DataDesignerConfigBuilder
12-
from data_designer.config.default_model_settings import resolve_seed_default_model_settings
1312
from data_designer.config.utils.io_helpers import VALID_CONFIG_FILE_EXTENSIONS, is_http_url
1413

1514

@@ -23,18 +22,6 @@ class ConfigLoadError(Exception):
2322
USER_MODULE_FUNC_NAME = "load_config_builder"
2423

2524

26-
_default_settings_initialized = False
27-
28-
29-
def _ensure_default_model_settings() -> None:
30-
"""Initialize default model/provider files once before loading CLI configs."""
31-
global _default_settings_initialized
32-
if _default_settings_initialized:
33-
return
34-
resolve_seed_default_model_settings()
35-
_default_settings_initialized = True
36-
37-
3825
def load_config_builder(config_source: str) -> DataDesignerConfigBuilder:
3926
"""Load a DataDesignerConfigBuilder from a file path or URL.
4027
@@ -52,8 +39,6 @@ def load_config_builder(config_source: str) -> DataDesignerConfigBuilder:
5239
Raises:
5340
ConfigLoadError: If the file cannot be loaded or is invalid.
5441
"""
55-
_ensure_default_model_settings()
56-
5742
if is_http_url(config_source):
5843
return _load_from_config_url(config_source)
5944

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from unittest.mock import Mock, call, patch
7+
8+
from typer.testing import CliRunner
9+
10+
from data_designer.cli.main import app, main
11+
from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS
12+
13+
runner = CliRunner()
14+
15+
16+
@patch("data_designer.cli.main.app")
17+
@patch("data_designer.cli.main.ensure_cli_default_model_settings")
18+
def test_main_bootstraps_before_running_app(mock_bootstrap: Mock, mock_app: Mock) -> None:
19+
"""The CLI entrypoint bootstraps defaults before invoking Typer."""
20+
call_order = Mock()
21+
call_order.attach_mock(mock_bootstrap, "bootstrap")
22+
call_order.attach_mock(mock_app, "app")
23+
24+
main()
25+
26+
assert call_order.mock_calls == [call.bootstrap(), call.app()]
27+
28+
29+
@patch("data_designer.cli.commands.create.GenerationController")
30+
def test_app_dispatches_lazy_create_command(mock_controller_cls: Mock) -> None:
31+
"""The Typer app dispatches lazy-loaded commands through the resolved callback."""
32+
mock_controller = Mock()
33+
mock_controller_cls.return_value = mock_controller
34+
35+
result = runner.invoke(app, ["create", "config.yaml"])
36+
37+
assert result.exit_code == 0
38+
mock_controller.run_create.assert_called_once_with(
39+
config_source="config.yaml",
40+
num_records=DEFAULT_NUM_RECORDS,
41+
dataset_name="dataset",
42+
artifact_path=None,
43+
)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from unittest.mock import patch
7+
8+
import data_designer.cli.runtime as runtime_mod
9+
10+
11+
def test_ensure_cli_default_model_settings_attempts_default_setup() -> None:
12+
"""CLI bootstrap delegates to default setup when the CLI starts."""
13+
with (
14+
patch("data_designer.cli.runtime.print_warning") as mock_print_warning,
15+
patch("data_designer.cli.runtime.resolve_seed_default_model_settings") as mock_resolve,
16+
):
17+
runtime_mod.ensure_cli_default_model_settings()
18+
19+
mock_resolve.assert_called_once_with()
20+
mock_print_warning.assert_not_called()
21+
22+
23+
def test_ensure_cli_default_model_settings_warns_and_continues() -> None:
24+
"""CLI bootstrap prints an actionable warning when setup fails."""
25+
with (
26+
patch("data_designer.cli.runtime.print_warning") as mock_print_warning,
27+
patch(
28+
"data_designer.cli.runtime.resolve_seed_default_model_settings",
29+
side_effect=RuntimeError("boom"),
30+
) as mock_resolve,
31+
):
32+
runtime_mod.ensure_cli_default_model_settings()
33+
34+
mock_resolve.assert_called_once_with()
35+
mock_print_warning.assert_called_once()
36+
warning = mock_print_warning.call_args[0][0]
37+
assert "Could not initialize default model providers and model configs automatically." in warning
38+
assert "The command will continue." in warning
39+
assert "boom" in warning
40+
assert "data-designer config providers" in warning
41+
assert "data-designer config models" in warning

packages/data-designer/tests/cli/utils/test_config_loader.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from __future__ import annotations
5+
46
from pathlib import Path
57
from unittest.mock import MagicMock, patch
68

79
import pytest
810

9-
import data_designer.cli.utils.config_loader as config_loader_mod
1011
from data_designer.cli.utils.config_loader import (
1112
ConfigLoadError,
1213
load_config_builder,
@@ -288,13 +289,3 @@ def test_load_config_builder_empty_yaml(tmp_path: Path) -> None:
288289

289290
with pytest.raises(ConfigLoadError, match="Failed to load config from"):
290291
load_config_builder(str(yaml_file))
291-
292-
293-
def test_ensure_default_model_settings_runs_once(monkeypatch: pytest.MonkeyPatch) -> None:
294-
"""_ensure_default_model_settings only calls resolve_seed_default_model_settings once."""
295-
monkeypatch.setattr(config_loader_mod, "_default_settings_initialized", False)
296-
297-
with patch("data_designer.cli.utils.config_loader.resolve_seed_default_model_settings") as mock_resolve:
298-
config_loader_mod._ensure_default_model_settings()
299-
config_loader_mod._ensure_default_model_settings()
300-
mock_resolve.assert_called_once()

0 commit comments

Comments
 (0)