diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 8d1d427e..b5740708 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -87,6 +87,7 @@ def __init__( tracker: CompletionTracker, row_groups: list[tuple[int, int]], buffer_manager: RowGroupBufferManager | None = None, + side_effect_map: dict[str, str] | None = None, *, max_concurrent_row_groups: int = 3, max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE, @@ -137,6 +138,14 @@ def __init__( instance_to_columns.setdefault(id(gen), []).append(col) self._instance_to_columns = instance_to_columns + # Extend with side-effect columns for buffer write-back only. + # _instance_to_columns stays unchanged for completion tracking / dispatch dedup. + write_cols: dict[int, list[str]] = {k: list(v) for k, v in instance_to_columns.items()} + for se_col, primary_col in (side_effect_map or {}).items(): + gen = generators[primary_col] + write_cols.setdefault(id(gen), []).append(se_col) + self._instance_to_write_columns = write_cols + # Stateful generator tracking: instance_id → asyncio.Lock self._stateful_locks: dict[int, asyncio.Lock] = {} for col, gen in generators.items(): @@ -767,7 +776,7 @@ async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any # Write results to buffer if self._buffer_manager is not None: - output_cols = self._instance_to_columns.get(id(generator), [task.column]) + output_cols = self._instance_to_write_columns.get(id(generator), [task.column]) for col in output_cols: if col in result_df.columns: values = result_df[col].tolist() @@ -793,7 +802,7 @@ async def _run_cell(self, task: Task, generator: ColumnGenerator) -> Any: # Write back to buffer if self._buffer_manager is not None and not self._tracker.is_dropped(task.row_group, task.row_index): - output_cols = self._instance_to_columns.get(id(generator), [task.column]) + output_cols = self._instance_to_write_columns.get(id(generator), [task.column]) for col in output_cols: if col in result: self._buffer_manager.update_cell(task.row_group, task.row_index, col, result[col]) @@ -817,7 +826,7 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: # Merge result columns back to buffer if self._buffer_manager is not None: - output_cols = self._instance_to_columns.get(id(generator), [task.column]) + output_cols = self._instance_to_write_columns.get(id(generator), [task.column]) active_rows = rg_size - len(pre_dropped) if len(result_df) != active_rows: raise ValueError( diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index f9705434..d8d0452e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -325,14 +325,19 @@ def _prepare_async_run( """ strategies: dict[str, GenerationStrategy] = {} gen_map: dict[str, ColumnGenerator] = {} + side_effect_map: dict[str, str] = {} for gen in generators: if isinstance(gen.config, MultiColumnConfig): for sub in gen.config.columns: strategies[sub.name] = gen.get_generation_strategy() gen_map[sub.name] = gen + for se_col in sub.side_effect_columns: + side_effect_map[se_col] = sub.name else: strategies[gen.config.name] = gen.get_generation_strategy() gen_map[gen.config.name] = gen + for se_col in gen.config.side_effect_columns: + side_effect_map[se_col] = gen.config.name graph = ExecutionGraph.create(self._column_configs, strategies) @@ -379,6 +384,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: tracker=tracker, row_groups=row_groups, buffer_manager=buffer_manager, + side_effect_map=side_effect_map, max_submitted_tasks=DEFAULT_TASK_POOL_SIZE, max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, LLM_WAIT_POOL_MULTIPLIER * aggregate), on_finalize_row_group=on_finalize_row_group, diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index 466ed697..a9bd76b7 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -1551,3 +1551,77 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None: f"First judge dispatched at {first_judge_dispatched:.4f}, " f"last gen dispatched at {last_gen_dispatched:.4f}." ) + + +class MockCellGeneratorWithSideEffect(ColumnGenerator[ExpressionColumnConfig]): + """Cell generator that produces a side-effect column alongside its primary output.""" + + def __init__(self, *args: Any, side_effect_col: str, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._side_effect_col = side_effect_col + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + data[self.config.name] = f"primary_{data.get('seed', '?')}" + data[self._side_effect_col] = f"side_effect_{data.get('seed', '?')}" + return data + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_side_effect_columns_written_to_buffer() -> None: + """Side-effect columns (e.g. __reasoning_content) are persisted to the buffer. + + Reproduces the bug where a generator produces extra columns (like trace or + reasoning_content) that aren't tracked in _instance_to_columns. Downstream + columns that reference these side-effect values must find them in the buffer. + """ + provider = _mock_provider() + side_effect_col = "answer__reasoning_content" + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig( + name="answer", prompt="{{ seed }}", model_alias=MODEL_ALIAS, extract_reasoning_content=True + ), + LLMTextColumnConfig(name="judge", prompt=f"{{{{ {side_effect_col} }}}}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "answer": GenerationStrategy.CELL_BY_CELL, + "judge": GenerationStrategy.CELL_BY_CELL, + } + + answer_gen = MockCellGeneratorWithSideEffect( + config=_expr_config("answer"), + resource_provider=provider, + side_effect_col=side_effect_col, + ) + generators: dict[str, ColumnGenerator] = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "answer": answer_gen, + "judge": MockCellGenerator(config=_expr_config("judge"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + buffer_manager = RowGroupBufferManager(MagicMock()) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + side_effect_map={side_effect_col: "answer"}, + ) + await scheduler.run() + + assert tracker.is_row_group_complete(0, 3, ["seed", "answer", "judge"]) + for ri in range(3): + row = buffer_manager.get_row(0, ri) + assert side_effect_col in row, f"Side-effect column missing from row {ri}" + assert row[side_effect_col].startswith("side_effect_")