Skip to content

Commit 5ec0e89

Browse files
authored
fix(interface): don't leak YAML default provider into user-supplied list (#591)
* fix(interface): don't leak YAML default provider into user-supplied list When a caller passes ``model_providers`` to ``DataDesigner.__init__``, the YAML's ``default:`` key from ``~/.data-designer/model_providers.yaml`` was still being applied to the resulting ``ModelProviderRegistry``. This caused two related problems: 1. Hard failure: if the YAML default named a provider absent from the user-supplied list, construction raised ``ValidationError: Specified default 'X' not found in providers list``. 2. Silent override: if the YAML default matched a non-first user-supplied provider, the documented "first wins" behavior was silently overridden. Gate the YAML lookup on ``model_providers is None`` so that user-supplied lists own their own default. Also expose ``model_provider_registry`` and ``run_config`` as public read-only properties on ``DataDesigner``, paired with the existing ``secret_resolver`` property and ``set_run_config`` setter; tests now use these instead of the underscore-prefixed attributes. Closes #588. Made-with: Cursor * fix(interface): tighten model_provider_registry docstring; pin YAML-fallback path Address review feedback on PR #591: - Clarify ``model_provider_registry`` docstring so it reflects the full fallback chain: user-supplied first → YAML default (when set) → first provider in the YAML list. - Add ``test_init_no_user_providers_uses_yaml_default`` to lock the YAML-fallback contract that the #588 fix preserved but didn't pin. Made-with: Cursor
1 parent 81033e6 commit 5ec0e89

2 files changed

Lines changed: 165 additions & 39 deletions

File tree

packages/data-designer/src/data_designer/interface/data_designer.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from data_designer.engine.compiler import compile_data_designer_config
3838
from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder
3939
from data_designer.engine.mcp.io import list_tool_names
40-
from data_designer.engine.model_provider import resolve_model_provider_registry
40+
from data_designer.engine.model_provider import ModelProviderRegistry, resolve_model_provider_registry
4141
from data_designer.engine.resources.person_reader import (
4242
PersonReader,
4343
create_person_reader,
@@ -150,11 +150,20 @@ def __init__(
150150
self._run_config = RunConfig()
151151
self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH)
152152
self._person_reader = person_reader
153-
self._model_providers = self._resolve_model_providers(model_providers)
153+
# Only consult the YAML's `default:` key when we are also falling back to
154+
# the YAML's `providers:` list. A user-supplied `model_providers` list
155+
# owns its own default (first wins), so the YAML default must not leak
156+
# in and either (a) hard-fail validation when the YAML names a provider
157+
# absent from the supplied list or (b) silently override the
158+
# documented first-wins ordering. See issue #588.
159+
if model_providers is None:
160+
self._model_providers = self._resolve_model_providers(None)
161+
default_provider_name = get_default_provider_name()
162+
else:
163+
self._model_providers = self._resolve_model_providers(model_providers)
164+
default_provider_name = None
154165
self._mcp_providers = mcp_providers or []
155-
self._model_provider_registry = resolve_model_provider_registry(
156-
self._model_providers, get_default_provider_name()
157-
)
166+
self._model_provider_registry = resolve_model_provider_registry(self._model_providers, default_provider_name)
158167
self._seed_reader_registry = SeedReaderRegistry(readers=seed_readers or DEFAULT_SEED_READERS)
159168

160169
@property
@@ -423,6 +432,32 @@ def secret_resolver(self) -> SecretResolver:
423432
"""
424433
return self._secret_resolver
425434

435+
@property
436+
def model_provider_registry(self) -> ModelProviderRegistry:
437+
"""Get the resolved model provider registry.
438+
439+
Returns:
440+
The ModelProviderRegistry containing the providers and default
441+
resolved at construction time. The default is taken from the
442+
first user-supplied provider when ``model_providers`` was passed
443+
to the constructor; otherwise from the YAML's ``default:`` key
444+
when set, falling back to the first provider in the YAML list.
445+
"""
446+
return self._model_provider_registry
447+
448+
@property
449+
def run_config(self) -> RunConfig:
450+
"""Get the runtime configuration applied to dataset generation.
451+
452+
Returns:
453+
The active RunConfig instance. Note that ``RunConfig`` normalizes
454+
some fields on construction (e.g., ``shutdown_error_rate`` becomes
455+
``1.0`` when ``disable_early_shutdown=True``), so the returned
456+
object may not exactly equal the one originally passed to
457+
``set_run_config``.
458+
"""
459+
return self._run_config
460+
426461
def set_run_config(self, run_config: RunConfig) -> None:
427462
"""Set the runtime configuration for dataset generation.
428463

packages/data-designer/tests/interface/test_data_designer.py

Lines changed: 125 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -420,17 +420,110 @@ def test_init_with_path_object(stub_artifact_path, stub_model_providers):
420420
assert designer is not None
421421

422422

423+
def test_init_user_supplied_providers_ignore_unrelated_yaml_default(
424+
stub_artifact_path: Path,
425+
stub_model_providers: list[ModelProvider],
426+
stub_managed_assets_path: Path,
427+
) -> None:
428+
"""Regression for #588: a YAML ``default:`` that names a provider absent
429+
from a user-supplied ``model_providers`` list must not leak into
430+
construction.
431+
432+
Pre-fix this raised ``ValidationError: Specified default 'unrelated' not
433+
found in providers list``.
434+
"""
435+
with patch.object(dd_mod, "get_default_provider_name", return_value="unrelated"):
436+
data_designer = DataDesigner(
437+
artifact_path=stub_artifact_path,
438+
model_providers=stub_model_providers,
439+
secret_resolver=PlaintextResolver(),
440+
managed_assets_path=stub_managed_assets_path,
441+
)
442+
443+
assert data_designer.model_provider_registry.get_default_provider_name() == "stub-model-provider"
444+
445+
446+
def test_init_user_supplied_providers_preserve_first_wins_over_yaml_default(
447+
stub_artifact_path: Path,
448+
stub_managed_assets_path: Path,
449+
) -> None:
450+
"""Regression for #588: when the YAML ``default:`` matches a user-supplied
451+
provider that isn't first in the list, the documented ``model_providers[0]``
452+
"first wins" behavior must not be silently overridden.
453+
"""
454+
user_providers = [
455+
ModelProvider(
456+
name="first-provider",
457+
endpoint="https://first.example.com/v1",
458+
api_key="FIRST_API_KEY",
459+
),
460+
ModelProvider(
461+
name="second-provider",
462+
endpoint="https://second.example.com/v1",
463+
api_key="SECOND_API_KEY",
464+
),
465+
]
466+
467+
with patch.object(dd_mod, "get_default_provider_name", return_value="second-provider"):
468+
data_designer = DataDesigner(
469+
artifact_path=stub_artifact_path,
470+
model_providers=user_providers,
471+
secret_resolver=PlaintextResolver(),
472+
managed_assets_path=stub_managed_assets_path,
473+
)
474+
475+
assert data_designer.model_provider_registry.get_default_provider_name() == "first-provider"
476+
477+
478+
def test_init_no_user_providers_uses_yaml_default(
479+
stub_artifact_path: Path,
480+
stub_managed_assets_path: Path,
481+
) -> None:
482+
"""Pin the unchanged YAML-fallback path: when the caller omits
483+
``model_providers``, DataDesigner consults both ``providers:`` and
484+
``default:`` from the YAML.
485+
486+
The fix in #588 only changes the user-supplied branch; this test locks the
487+
YAML-fallback branch's contract so a future refactor can't silently regress
488+
it.
489+
"""
490+
yaml_providers = [
491+
ModelProvider(
492+
name="yaml-first",
493+
endpoint="https://yaml-first.example.com/v1",
494+
api_key="yaml-first-key",
495+
),
496+
ModelProvider(
497+
name="yaml-second",
498+
endpoint="https://yaml-second.example.com/v1",
499+
api_key="yaml-second-key",
500+
),
501+
]
502+
503+
with (
504+
patch.object(dd_mod, "get_default_providers", return_value=yaml_providers),
505+
patch.object(dd_mod, "get_default_provider_name", return_value="yaml-second"),
506+
):
507+
data_designer = DataDesigner(
508+
artifact_path=stub_artifact_path,
509+
secret_resolver=PlaintextResolver(),
510+
managed_assets_path=stub_managed_assets_path,
511+
)
512+
513+
assert data_designer.model_provider_registry.get_default_provider_name() == "yaml-second"
514+
515+
423516
def test_run_config_setting_persists(stub_artifact_path, stub_model_providers):
424517
"""Test that run config setting persists across multiple calls."""
425518
data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers)
426519

427520
# Test default values
428-
assert data_designer._run_config.disable_early_shutdown is False
429-
assert data_designer._run_config.shutdown_error_rate == 0.5
430-
assert data_designer._run_config.shutdown_error_window == 10
431-
assert data_designer._run_config.buffer_size == 1000
432-
assert data_designer._run_config.max_conversation_restarts == 5
433-
assert data_designer._run_config.max_conversation_correction_steps == 0
521+
assert data_designer.run_config.disable_early_shutdown is False
522+
assert data_designer.run_config.shutdown_error_rate == 0.5
523+
assert data_designer.run_config.shutdown_error_window == 10
524+
assert data_designer.run_config.buffer_size == 1000
525+
assert data_designer.run_config.max_conversation_restarts == 5
526+
assert data_designer.run_config.max_conversation_correction_steps == 0
434527

435528
# Test setting custom values
436529
data_designer.set_run_config(
@@ -443,12 +536,12 @@ def test_run_config_setting_persists(stub_artifact_path, stub_model_providers):
443536
max_conversation_correction_steps=2,
444537
)
445538
)
446-
assert data_designer._run_config.disable_early_shutdown is True
447-
assert data_designer._run_config.shutdown_error_rate == 1.0 # normalized when disabled
448-
assert data_designer._run_config.shutdown_error_window == 25
449-
assert data_designer._run_config.buffer_size == 500
450-
assert data_designer._run_config.max_conversation_restarts == 7
451-
assert data_designer._run_config.max_conversation_correction_steps == 2
539+
assert data_designer.run_config.disable_early_shutdown is True
540+
assert data_designer.run_config.shutdown_error_rate == 1.0 # normalized when disabled
541+
assert data_designer.run_config.shutdown_error_window == 25
542+
assert data_designer.run_config.buffer_size == 500
543+
assert data_designer.run_config.max_conversation_restarts == 7
544+
assert data_designer.run_config.max_conversation_correction_steps == 2
452545

453546
# Test updating values
454547
data_designer.set_run_config(
@@ -461,12 +554,12 @@ def test_run_config_setting_persists(stub_artifact_path, stub_model_providers):
461554
max_conversation_correction_steps=1,
462555
)
463556
)
464-
assert data_designer._run_config.disable_early_shutdown is False
465-
assert data_designer._run_config.shutdown_error_rate == 0.3
466-
assert data_designer._run_config.shutdown_error_window == 5
467-
assert data_designer._run_config.buffer_size == 750
468-
assert data_designer._run_config.max_conversation_restarts == 9
469-
assert data_designer._run_config.max_conversation_correction_steps == 1
557+
assert data_designer.run_config.disable_early_shutdown is False
558+
assert data_designer.run_config.shutdown_error_rate == 0.3
559+
assert data_designer.run_config.shutdown_error_window == 5
560+
assert data_designer.run_config.buffer_size == 750
561+
assert data_designer.run_config.max_conversation_restarts == 9
562+
assert data_designer.run_config.max_conversation_correction_steps == 1
470563

471564

472565
def test_run_config_normalizes_error_rate_when_disabled(stub_artifact_path, stub_model_providers):
@@ -480,7 +573,7 @@ def test_run_config_normalizes_error_rate_when_disabled(stub_artifact_path, stub
480573
shutdown_error_rate=0.7,
481574
)
482575
)
483-
assert data_designer._run_config.shutdown_error_rate == 0.7
576+
assert data_designer.run_config.shutdown_error_rate == 0.7
484577

485578
# When disabled, shutdown_error_rate should be normalized to 1.0
486579
data_designer.set_run_config(
@@ -489,7 +582,7 @@ def test_run_config_normalizes_error_rate_when_disabled(stub_artifact_path, stub
489582
shutdown_error_rate=0.7,
490583
)
491584
)
492-
assert data_designer._run_config.shutdown_error_rate == 1.0
585+
assert data_designer.run_config.shutdown_error_rate == 1.0
493586

494587

495588
def test_run_config_rejects_invalid_buffer_size() -> None:
@@ -858,13 +951,12 @@ def test_create_logs_secure_jinja_rendering_mode(
858951
stub_sampler_only_config_builder: DataDesignerConfigBuilder,
859952
stub_managed_assets_path: Path,
860953
) -> None:
861-
with patch.object(dd_mod, "get_default_provider_name", return_value="stub-model-provider"):
862-
data_designer = DataDesigner(
863-
artifact_path=stub_artifact_path,
864-
model_providers=stub_model_providers,
865-
secret_resolver=PlaintextResolver(),
866-
managed_assets_path=stub_managed_assets_path,
867-
)
954+
data_designer = DataDesigner(
955+
artifact_path=stub_artifact_path,
956+
model_providers=stub_model_providers,
957+
secret_resolver=PlaintextResolver(),
958+
managed_assets_path=stub_managed_assets_path,
959+
)
868960
data_designer.set_run_config(RunConfig(jinja_rendering_engine=JinjaRenderingEngine.SECURE))
869961

870962
with (
@@ -898,13 +990,12 @@ def test_preview_logs_native_jinja_rendering_mode(
898990
stub_sampler_only_config_builder: DataDesignerConfigBuilder,
899991
stub_managed_assets_path: Path,
900992
) -> None:
901-
with patch.object(dd_mod, "get_default_provider_name", return_value="stub-model-provider"):
902-
data_designer = DataDesigner(
903-
artifact_path=stub_artifact_path,
904-
model_providers=stub_model_providers,
905-
secret_resolver=PlaintextResolver(),
906-
managed_assets_path=stub_managed_assets_path,
907-
)
993+
data_designer = DataDesigner(
994+
artifact_path=stub_artifact_path,
995+
model_providers=stub_model_providers,
996+
secret_resolver=PlaintextResolver(),
997+
managed_assets_path=stub_managed_assets_path,
998+
)
908999
data_designer.set_run_config(RunConfig(jinja_rendering_engine=JinjaRenderingEngine.NATIVE))
9091000

9101001
with (

0 commit comments

Comments
 (0)