Skip to content

Commit 2d38c88

Browse files
committed
fix: harden workflow chaining concurrency
1 parent cdb6a76 commit 2d38c88

8 files changed

Lines changed: 194 additions & 70 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/models/factory.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
if TYPE_CHECKING:
1414
from data_designer.config.run_config import RunConfig
1515
from data_designer.engine.mcp.registry import MCPRegistry
16+
from data_designer.engine.models.clients.throttle_manager import ThrottleManager
1617
from data_designer.engine.models.registry import ModelRegistry
1718

1819

@@ -24,6 +25,7 @@ def create_model_registry(
2425
mcp_registry: MCPRegistry | None = None,
2526
client_concurrency_mode: ClientConcurrencyMode = ClientConcurrencyMode.SYNC,
2627
run_config: RunConfig | None = None,
28+
throttle_manager: ThrottleManager | None = None,
2729
) -> ModelRegistry:
2830
"""Factory function for creating a ModelRegistry instance.
2931
@@ -43,6 +45,8 @@ def create_model_registry(
4345
run_config: Optional runtime configuration. The nested
4446
``run_config.throttle`` (a ``ThrottleConfig``) is forwarded to the
4547
``ThrottleManager`` constructor.
48+
throttle_manager: Optional shared throttle manager. When omitted, a new
49+
manager is created for this registry.
4650
4751
Returns:
4852
A configured ModelRegistry instance.
@@ -54,7 +58,8 @@ def create_model_registry(
5458
from data_designer.engine.models.facade import ModelFacade
5559
from data_designer.engine.models.registry import ModelRegistry
5660

57-
throttle_manager = ThrottleManager((run_config or RunConfig()).throttle)
61+
if throttle_manager is None:
62+
throttle_manager = ThrottleManager((run_config or RunConfig()).throttle)
5863

5964
def model_facade_factory(
6065
model_config: ModelConfig,

packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import os
7+
from typing import TYPE_CHECKING
78

89
from data_designer.config.base import ConfigBase
910
from data_designer.config.dataset_metadata import DatasetMetadata
@@ -26,6 +27,9 @@
2627
from data_designer.engine.secret_resolver import SecretResolver
2728
from data_designer.engine.storage.artifact_storage import ArtifactStorage
2829

30+
if TYPE_CHECKING:
31+
from data_designer.engine.models.clients.throttle_manager import ThrottleManager
32+
2933

3034
class ResourceType(StrEnum):
3135
PERSON_READER = "person_reader"
@@ -91,6 +95,7 @@ def create_resource_provider(
9195
mcp_providers: list[MCPProviderT] | None = None,
9296
tool_configs: list[ToolConfig] | None = None,
9397
client_concurrency_mode: ClientConcurrencyMode | None = None,
98+
throttle_manager: ThrottleManager | None = None,
9499
) -> ResourceProvider:
95100
"""Factory function for creating a ResourceProvider instance.
96101
@@ -111,6 +116,7 @@ def create_resource_provider(
111116
run_config: Optional runtime configuration.
112117
mcp_providers: Optional list of MCP provider configurations.
113118
tool_configs: Optional list of tool configurations.
119+
throttle_manager: Optional shared throttle manager for model clients.
114120
115121
Returns:
116122
A configured ResourceProvider instance.
@@ -158,6 +164,7 @@ def create_resource_provider(
158164
mcp_registry=mcp_registry,
159165
client_concurrency_mode=client_concurrency_mode,
160166
run_config=effective_run_config,
167+
throttle_manager=throttle_manager,
161168
),
162169
person_reader=person_reader,
163170
mcp_registry=mcp_registry,

packages/data-designer-engine/src/data_designer/engine/resources/seed_reader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
from abc import ABC, abstractmethod
88
from collections.abc import Callable, Iterable, Sequence
9+
from copy import copy
910
from dataclasses import dataclass
1011
from fnmatch import fnmatchcase
1112
from pathlib import Path, PurePosixPath
@@ -673,7 +674,7 @@ def add_reader(self, reader: SeedReader) -> Self:
673674
return self
674675

675676
def get_reader(self, seed_dataset_source: SeedSource, secret_resolver: SecretResolver) -> SeedReader:
676-
reader = self._get_reader_for_source(seed_dataset_source)
677+
reader = copy(self._get_reader_for_source(seed_dataset_source))
677678
reader.attach(seed_dataset_source, secret_resolver)
678679
return reader
679680

packages/data-designer-engine/tests/engine/resources/test_seed_reader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,9 @@ def test_get_reader_basic():
412412

413413
reader = registry.get_reader(local_seed_config, PlaintextResolver())
414414

415-
assert reader == df_reader
415+
assert isinstance(reader, DataFrameSeedReader)
416+
assert reader is not df_reader
417+
assert reader.source is local_seed_config
416418

417419

418420
def test_get_reader_missing():

packages/data-designer/src/data_designer/interface/composite_workflow.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
import hashlib
77
import json
8+
import logging
89
import shutil
910
import time
10-
from collections.abc import Callable, Iterator
11-
from contextlib import contextmanager
11+
from collections.abc import Callable, ItemsView, Iterator, KeysView
1212
from dataclasses import dataclass
1313
from pathlib import Path
1414
from typing import TYPE_CHECKING, Any
@@ -24,9 +24,14 @@
2424
from data_designer.interface.results import DatasetCreationResults
2525

2626
if TYPE_CHECKING:
27+
import pandas as pd
28+
29+
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
2730
from data_designer.interface.data_designer import DataDesigner
2831

2932

33+
logger = logging.getLogger(__name__)
34+
3035
OnSuccessCallback = Callable[[Path], Path | str]
3136

3237

@@ -50,13 +55,21 @@ class SkippedStageResult:
5055

5156

5257
class CompositeWorkflowResults:
58+
"""Results for a composite workflow run.
59+
60+
Per-stage entries are the original ``DataDesigner.create()`` results. If a
61+
stage uses ``on_success``, metadata and downstream seeding use the callback
62+
output path while the stage result still points at the stage's generated
63+
dataset.
64+
"""
65+
5366
def __init__(
5467
self,
5568
*,
5669
name: str,
5770
stage_results: dict[str, DatasetCreationResults | SkippedStageResult],
5871
final_stage_name: str,
59-
):
72+
) -> None:
6073
self.name = name
6174
self.stage_results = stage_results
6275
self.final_stage_name = final_stage_name
@@ -67,10 +80,10 @@ def __getitem__(self, stage_name: str) -> DatasetCreationResults | SkippedStageR
6780
def __iter__(self) -> Iterator[str]:
6881
return iter(self.stage_results)
6982

70-
def keys(self):
83+
def keys(self) -> KeysView[str]:
7184
return self.stage_results.keys()
7285

73-
def items(self):
86+
def items(self) -> ItemsView[str, DatasetCreationResults | SkippedStageResult]:
7487
return self.stage_results.items()
7588

7689
@property
@@ -80,24 +93,24 @@ def final_result(self) -> DatasetCreationResults:
8093
raise DataDesignerWorkflowError(f"Final stage {self.final_stage_name!r} was skipped: {result.status}.")
8194
return result
8295

83-
def load_dataset(self):
96+
def load_dataset(self) -> pd.DataFrame:
8497
return self.final_result.load_dataset()
8598

86-
def load_analysis(self):
99+
def load_analysis(self) -> DatasetProfilerResults:
87100
return self.final_result.load_analysis()
88101

89102
def count_records(self) -> int:
90103
return self.final_result.count_records()
91104

92-
def export(self, *args, **kwargs):
105+
def export(self, *args: Any, **kwargs: Any) -> Path:
93106
return self.final_result.export(*args, **kwargs)
94107

95-
def push_to_hub(self, *args, **kwargs):
108+
def push_to_hub(self, *args: Any, **kwargs: Any) -> str:
96109
return self.final_result.push_to_hub(*args, **kwargs)
97110

98111

99112
class CompositeWorkflow:
100-
def __init__(self, *, name: str, data_designer: DataDesigner):
113+
def __init__(self, *, name: str, data_designer: DataDesigner) -> None:
101114
_validate_dir_name(name, "workflow name")
102115
self.name = name
103116
self._data_designer = data_designer
@@ -123,7 +136,7 @@ def add_stage(
123136
self._stages.append(
124137
_WorkflowStage(
125138
name=name,
126-
config_builder=config_builder,
139+
config_builder=_clone_config_builder(config_builder),
127140
depends_on=(self._stages[-1].name,) if self._stages else (),
128141
num_records=num_records,
129142
on_success=on_success,
@@ -136,6 +149,7 @@ def add_stage(
136149
return self
137150

138151
def run(self) -> CompositeWorkflowResults:
152+
"""Run all stages from scratch, replacing deterministic stage directories."""
139153
if not self._stages:
140154
raise DataDesignerWorkflowError(f"Workflow {self.name!r} has no stages.")
141155

@@ -174,6 +188,12 @@ def run(self) -> CompositeWorkflowResults:
174188

175189
stage_builder = _clone_config_builder(stage.config_builder)
176190
if previous_seed_path is not None:
191+
if stage_builder.get_seed_config() is not None:
192+
logger.warning(
193+
"Stage %r has a seed dataset; workflow will seed it from upstream stage %r.",
194+
stage.name,
195+
previous_stage_name,
196+
)
177197
stage_builder.with_seed_dataset(
178198
_local_seed_source_from_path(previous_seed_path),
179199
sampling_strategy=stage.sampling_strategy,
@@ -206,12 +226,12 @@ def run(self) -> CompositeWorkflowResults:
206226

207227
start_time = time.monotonic()
208228
try:
209-
with _temporary_artifact_path(self._data_designer, workflow_path):
210-
result = self._data_designer.create(
211-
stage_builder,
212-
num_records=num_records,
213-
dataset_name=stage_dir_name,
214-
)
229+
result = self._data_designer.create(
230+
stage_builder,
231+
num_records=num_records,
232+
dataset_name=stage_dir_name,
233+
artifact_path=workflow_path,
234+
)
215235
actual_records = result.count_records()
216236
output_seed_path = result.artifact_storage.final_dataset_path
217237
callback_output_path = None
@@ -263,16 +283,6 @@ def _clone_config_builder(config_builder: DataDesignerConfigBuilder) -> DataDesi
263283
return DataDesignerConfigBuilder.from_config(BuilderConfig(data_designer=config_builder.build()))
264284

265285

266-
@contextmanager
267-
def _temporary_artifact_path(data_designer: DataDesigner, artifact_path: Path):
268-
original_artifact_path = data_designer._artifact_path
269-
data_designer._artifact_path = artifact_path
270-
try:
271-
yield
272-
finally:
273-
data_designer._artifact_path = original_artifact_path
274-
275-
276286
def _stage_dir_name(index: int, name: str) -> str:
277287
return f"stage-{index}-{name}"
278288

0 commit comments

Comments
 (0)