Skip to content

Commit ae70ca8

Browse files
committed
fix: address workflow review nits
1 parent 4dc1e68 commit ae70ca8

5 files changed

Lines changed: 44 additions & 10 deletions

File tree

packages/data-designer/src/data_designer/interface/composite_workflow.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def items(self) -> ItemsView[str, DatasetCreationResults | SkippedStageResult]:
107107

108108
@property
109109
def final_result(self) -> DatasetCreationResults:
110+
"""Return the final stage result, or raise if it was skipped."""
110111
return self._require_final_result()
111112

112113
def _require_final_result(self) -> DatasetCreationResults:
@@ -118,33 +119,41 @@ def _require_final_result(self) -> DatasetCreationResults:
118119
return result
119120

120121
def load_dataset(self) -> pd.DataFrame:
122+
"""Load the selected output from the final workflow stage."""
121123
self._require_final_result()
122124
return self.load_stage_output(self.final_stage_name)
123125

124126
def load_analysis(self) -> DatasetProfilerResults:
127+
"""Load analysis from the final stage result."""
125128
return self.final_result.load_analysis()
126129

127130
def count_records(self) -> int:
131+
"""Count records in the selected output from the final workflow stage."""
128132
self._require_final_result()
129133
return self.count_stage_output_records(self.final_stage_name)
130134

131135
def get_stage_output_path(self, stage_name: str) -> Path:
136+
"""Return the selected output path handed downstream for a stage."""
132137
result = self.stage_results[stage_name]
133138
if isinstance(result, SkippedStageResult):
134139
raise DataDesignerWorkflowError(f"Stage {stage_name!r} was skipped: {result.status.value}.")
135140
return self._stage_output_paths.get(stage_name, result.artifact_storage.final_dataset_path)
136141

137142
def load_stage_output(self, stage_name: str) -> pd.DataFrame:
143+
"""Load the selected output handed downstream for a stage."""
138144
return _load_parquet_dataset(self.get_stage_output_path(stage_name))
139145

140146
def count_stage_output_records(self, stage_name: str) -> int:
147+
"""Count records in the selected output handed downstream for a stage."""
141148
return _count_parquet_records(self.get_stage_output_path(stage_name))
142149

143150
def export(self, path: Path | str, *, format: ExportFormat | None = None) -> Path:
151+
"""Export the selected output from the final workflow stage."""
144152
self._require_final_result()
145153
return _export_parquet_dataset(self.get_stage_output_path(self.final_stage_name), Path(path), format=format)
146154

147155
def push_to_hub(self, *args: Any, **kwargs: Any) -> str:
156+
"""Push the final stage result to Hugging Face Hub when no output override is selected."""
148157
final_result = self.final_result
149158
if self.get_stage_output_path(self.final_stage_name) != final_result.artifact_storage.final_dataset_path:
150159
raise DataDesignerWorkflowError(
@@ -155,7 +164,10 @@ def push_to_hub(self, *args: Any, **kwargs: Any) -> str:
155164

156165

157166
class CompositeWorkflow:
167+
"""Experimental linear workflow for chaining Data Designer stages."""
168+
158169
def __init__(self, *, name: str, data_designer: DataDesigner) -> None:
170+
"""Create a workflow bound to a parent Data Designer instance."""
159171
_validate_dir_name(name, "workflow name")
160172
self.name = name
161173
self._data_designer = data_designer
@@ -210,11 +222,16 @@ def add_stage(
210222
return self
211223

212224
def run(self) -> CompositeWorkflowResults:
213-
"""Run all stages from scratch, replacing deterministic stage directories."""
225+
"""Run all stages from scratch.
226+
227+
Each stage writes a deterministic artifact directory under the parent
228+
Data Designer artifact path. Downstream stages are seeded from the
229+
selected output of the previous stage.
230+
"""
214231
if not self._stages:
215232
raise DataDesignerWorkflowError(f"Workflow {self.name!r} has no stages.")
216233

217-
workflow_path = self._data_designer._artifact_path / self.name
234+
workflow_path = self._data_designer.artifact_path / self.name
218235
workflow_path.mkdir(parents=True, exist_ok=True)
219236
metadata: dict[str, Any] = {
220237
"name": self.name,

packages/data-designer/src/data_designer/interface/data_designer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,11 @@ def info(self) -> InterfaceInfo:
215215
"""
216216
return self._get_interface_info(self._model_providers)
217217

218+
@property
219+
def artifact_path(self) -> Path:
220+
"""Directory where Data Designer writes artifacts by default."""
221+
return self._artifact_path
222+
218223
def list_mcp_tool_names(self, mcp_provider_name: str, *, timeout_sec: float = 10.0) -> list[str]:
219224
"""Connect to a configured MCP provider and return the names of its available tools.
220225
@@ -498,6 +503,17 @@ def preview(
498503
)
499504

500505
def compose_workflow(self, *, name: str) -> CompositeWorkflow:
506+
"""Create an experimental composite workflow.
507+
508+
Workflow chaining is experimental and its API, metadata schema, and
509+
artifact layout may change in future releases.
510+
511+
Args:
512+
name: Workflow name used for the artifact directory.
513+
514+
Returns:
515+
A composite workflow that can run named stages in sequence.
516+
"""
501517
return CompositeWorkflow(name=name, data_designer=self)
502518

503519
def _log_jinja_rendering_engine_mode(self) -> None:

packages/data-designer/tests/interface/test_acreate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def fake_create(
9797
started_count += 1
9898
if started_count == 2:
9999
both_started.set()
100-
assert both_started.wait(2)
101-
assert release.wait(2)
100+
assert both_started.wait(5)
101+
assert release.wait(5)
102102
return _creation_result(config_builder, stub_dataset_profiler_results)
103103

104104
data_designer.create = MagicMock(side_effect=fake_create)
@@ -108,7 +108,7 @@ def fake_create(
108108
left_task = asyncio.create_task(data_designer.acreate(left, num_records=1, dataset_name="left"))
109109
right_task = asyncio.create_task(data_designer.acreate(right, num_records=1, dataset_name="right"))
110110
try:
111-
assert await asyncio.to_thread(both_started.wait, 2)
111+
assert await asyncio.to_thread(both_started.wait, 5)
112112
finally:
113113
release.set()
114114

packages/data-designer/tests/interface/test_composite_workflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def fake_create(
7070
dataset_name: str,
7171
**kwargs,
7272
) -> DatasetCreationResults:
73-
artifact_path = Path(kwargs.pop("artifact_path", data_designer._artifact_path))
73+
artifact_path = Path(kwargs.pop("artifact_path", data_designer.artifact_path))
7474
del kwargs
7575
df = lazy.pd.DataFrame({"category": ["alpha"] * num_records, "category_copy": ["alpha"] * num_records})
7676
return _result_from_df(
@@ -177,7 +177,7 @@ def test_composite_workflow_runs_linear_stages_with_disk_handoff(
177177
assert "category_copy" in final_df.columns
178178
assert (stub_artifact_path / "linear-chain" / "stage-0-base").is_dir()
179179
assert (stub_artifact_path / "linear-chain" / "stage-1-copy").is_dir()
180-
assert data_designer._artifact_path == stub_artifact_path
180+
assert data_designer.artifact_path == stub_artifact_path
181181

182182
metadata = _load_workflow_metadata(stub_artifact_path, "linear-chain")
183183
assert [stage["status"] for stage in metadata["stages"]] == ["completed", "completed"]
@@ -880,7 +880,7 @@ def fake_create(
880880
dataset_name: str,
881881
**kwargs,
882882
) -> DatasetCreationResults:
883-
artifact_path = Path(kwargs.pop("artifact_path", data_designer._artifact_path))
883+
artifact_path = Path(kwargs.pop("artifact_path", data_designer.artifact_path))
884884
del num_records, kwargs
885885
value = "first" if dataset_name == "stage-0-first" else "final"
886886
return _result_from_df(
@@ -921,7 +921,7 @@ def fake_create(
921921
dataset_name: str,
922922
**kwargs,
923923
) -> DatasetCreationResults:
924-
artifact_path = Path(kwargs.pop("artifact_path", data_designer._artifact_path))
924+
artifact_path = Path(kwargs.pop("artifact_path", data_designer.artifact_path))
925925
del num_records, kwargs
926926
return _result_from_df(
927927
artifact_path,

packages/data-designer/tests/interface/test_data_designer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,8 @@ def test_init_with_string_path(stub_artifact_path, stub_model_providers):
491491
"""Test DataDesigner accepts string paths."""
492492
designer = DataDesigner(artifact_path=str(stub_artifact_path), model_providers=stub_model_providers)
493493
assert designer is not None
494-
assert isinstance(designer._artifact_path, Path)
494+
assert isinstance(designer.artifact_path, Path)
495+
assert designer.artifact_path == stub_artifact_path
495496

496497

497498
def test_init_with_path_object(stub_artifact_path, stub_model_providers):

0 commit comments

Comments
 (0)