55
66import hashlib
77import json
8+ import logging
89import shutil
910import time
10- from collections .abc import Callable , Iterator
11- from contextlib import contextmanager
11+ from collections .abc import Callable , ItemsView , Iterator , KeysView
1212from dataclasses import dataclass
1313from pathlib import Path
1414from typing import TYPE_CHECKING , Any
2424from data_designer .interface .results import DatasetCreationResults
2525
2626if TYPE_CHECKING :
27+ import pandas as pd
28+
29+ from data_designer .config .analysis .dataset_profiler import DatasetProfilerResults
2730 from data_designer .interface .data_designer import DataDesigner
2831
2932
33+ logger = logging .getLogger (__name__ )
34+
3035OnSuccessCallback = Callable [[Path ], Path | str ]
3136
3237
@@ -50,13 +55,21 @@ class SkippedStageResult:
5055
5156
5257class CompositeWorkflowResults :
58+ """Results for a composite workflow run.
59+
60+ Per-stage entries are the original ``DataDesigner.create()`` results. If a
61+ stage uses ``on_success``, metadata and downstream seeding use the callback
62+ output path while the stage result still points at the stage's generated
63+ dataset.
64+ """
65+
def __init__(
    self,
    *,
    name: str,
    stage_results: dict[str, DatasetCreationResults | SkippedStageResult],
    final_stage_name: str,
) -> None:
    """Store the workflow name, the per-stage results, and the final stage name.

    Args:
        name: Name of the workflow these results belong to.
        stage_results: Mapping from stage name to that stage's result, or to
            a ``SkippedStageResult`` when the stage did not run.
        final_stage_name: Name of the final stage; ``final_result`` reports
            it when that stage was skipped.
    """
    self.name = name
    self.stage_results = stage_results
    self.final_stage_name = final_stage_name
@@ -67,10 +80,10 @@ def __getitem__(self, stage_name: str) -> DatasetCreationResults | SkippedStageR
def __iter__(self) -> Iterator[str]:
    """Iterate over the stage names, in ``stage_results`` order."""
    return iter(self.stage_results)
6982
def keys(self) -> KeysView[str]:
    """Return a view of the stage names held in ``stage_results``."""
    return self.stage_results.keys()
7285
def items(self) -> ItemsView[str, DatasetCreationResults | SkippedStageResult]:
    """Return a view of ``(stage name, stage result)`` pairs from ``stage_results``."""
    return self.stage_results.items()
7588
7689 @property
@@ -80,24 +93,24 @@ def final_result(self) -> DatasetCreationResults:
8093 raise DataDesignerWorkflowError (f"Final stage { self .final_stage_name !r} was skipped: { result .status } ." )
8194 return result
8295
def load_dataset(self) -> pd.DataFrame:
    """Load the final stage's dataset.

    Delegates to ``final_result.load_dataset()``. Accessing ``final_result``
    may raise ``DataDesignerWorkflowError`` (e.g. when the final stage was
    skipped).
    """
    return self.final_result.load_dataset()
8598
def load_analysis(self) -> DatasetProfilerResults:
    """Load the final stage's analysis results.

    Delegates to ``final_result.load_analysis()``. Accessing ``final_result``
    may raise ``DataDesignerWorkflowError`` (e.g. when the final stage was
    skipped).
    """
    return self.final_result.load_analysis()
88101
def count_records(self) -> int:
    """Return the record count reported by the final stage's result.

    Delegates to ``final_result.count_records()``.
    """
    return self.final_result.count_records()
91104
def export(self, *args: Any, **kwargs: Any) -> Path:
    """Export the final stage's dataset.

    All positional and keyword arguments are forwarded unchanged to
    ``final_result.export()``, and its return value is passed back.
    """
    return self.final_result.export(*args, **kwargs)
94107
def push_to_hub(self, *args: Any, **kwargs: Any) -> str:
    """Push the final stage's dataset to the hub.

    All positional and keyword arguments are forwarded unchanged to
    ``final_result.push_to_hub()``, and its return value is passed back.
    """
    return self.final_result.push_to_hub(*args, **kwargs)
97110
98111
99112class CompositeWorkflow :
100- def __init__ (self , * , name : str , data_designer : DataDesigner ):
113+ def __init__ (self , * , name : str , data_designer : DataDesigner ) -> None :
101114 _validate_dir_name (name , "workflow name" )
102115 self .name = name
103116 self ._data_designer = data_designer
@@ -123,7 +136,7 @@ def add_stage(
123136 self ._stages .append (
124137 _WorkflowStage (
125138 name = name ,
126- config_builder = config_builder ,
139+ config_builder = _clone_config_builder ( config_builder ) ,
127140 depends_on = (self ._stages [- 1 ].name ,) if self ._stages else (),
128141 num_records = num_records ,
129142 on_success = on_success ,
@@ -136,6 +149,7 @@ def add_stage(
136149 return self
137150
138151 def run (self ) -> CompositeWorkflowResults :
152+ """Run all stages from scratch, replacing deterministic stage directories."""
139153 if not self ._stages :
140154 raise DataDesignerWorkflowError (f"Workflow { self .name !r} has no stages." )
141155
@@ -174,6 +188,12 @@ def run(self) -> CompositeWorkflowResults:
174188
175189 stage_builder = _clone_config_builder (stage .config_builder )
176190 if previous_seed_path is not None :
191+ if stage_builder .get_seed_config () is not None :
192+ logger .warning (
193+ "Stage %r has a seed dataset; workflow will seed it from upstream stage %r." ,
194+ stage .name ,
195+ previous_stage_name ,
196+ )
177197 stage_builder .with_seed_dataset (
178198 _local_seed_source_from_path (previous_seed_path ),
179199 sampling_strategy = stage .sampling_strategy ,
@@ -206,12 +226,12 @@ def run(self) -> CompositeWorkflowResults:
206226
207227 start_time = time .monotonic ()
208228 try :
209- with _temporary_artifact_path ( self ._data_designer , workflow_path ):
210- result = self . _data_designer . create (
211- stage_builder ,
212- num_records = num_records ,
213- dataset_name = stage_dir_name ,
214- )
229+ result = self ._data_designer . create (
230+ stage_builder ,
231+ num_records = num_records ,
232+ dataset_name = stage_dir_name ,
233+ artifact_path = workflow_path ,
234+ )
215235 actual_records = result .count_records ()
216236 output_seed_path = result .artifact_storage .final_dataset_path
217237 callback_output_path = None
@@ -263,16 +283,6 @@ def _clone_config_builder(config_builder: DataDesignerConfigBuilder) -> DataDesi
263283 return DataDesignerConfigBuilder .from_config (BuilderConfig (data_designer = config_builder .build ()))
264284
265285
266- @contextmanager
267- def _temporary_artifact_path (data_designer : DataDesigner , artifact_path : Path ):
268- original_artifact_path = data_designer ._artifact_path
269- data_designer ._artifact_path = artifact_path
270- try :
271- yield
272- finally :
273- data_designer ._artifact_path = original_artifact_path
274-
275-
def _stage_dir_name(index: int, name: str) -> str:
    """Build the directory name for a stage from its index and name.

    The result has the form ``stage-<index>-<name>``.
    """
    return f"stage-{index}-{name}"
278288
0 commit comments