-
Notifications
You must be signed in to change notification settings - Fork 136
fix: persist side-effect columns to row buffer in async engine #524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -325,14 +325,19 @@ def _prepare_async_run( | |
| """ | ||
| strategies: dict[str, GenerationStrategy] = {} | ||
| gen_map: dict[str, ColumnGenerator] = {} | ||
| side_effect_map: dict[str, str] = {} | ||
| for gen in generators: | ||
| if isinstance(gen.config, MultiColumnConfig): | ||
| for sub in gen.config.columns: | ||
| strategies[sub.name] = gen.get_generation_strategy() | ||
| gen_map[sub.name] = gen | ||
| for se_col in sub.side_effect_columns: | ||
| side_effect_map[se_col] = sub.name | ||
| else: | ||
| strategies[gen.config.name] = gen.get_generation_strategy() | ||
| gen_map[gen.config.name] = gen | ||
| for se_col in gen.config.side_effect_columns: | ||
| side_effect_map[se_col] = gen.config.name | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| graph = ExecutionGraph.create(self._column_configs, strategies) | ||
|
|
||
|
|
@@ -379,6 +384,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: | |
| tracker=tracker, | ||
| row_groups=row_groups, | ||
| buffer_manager=buffer_manager, | ||
| side_effect_map=side_effect_map, | ||
| max_submitted_tasks=DEFAULT_TASK_POOL_SIZE, | ||
| max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, LLM_WAIT_POOL_MULTIPLIER * aggregate), | ||
| on_finalize_row_group=on_finalize_row_group, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1551,3 +1551,77 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None: | |
| f"First judge dispatched at {first_judge_dispatched:.4f}, " | ||
| f"last gen dispatched at {last_gen_dispatched:.4f}." | ||
| ) | ||
|
|
||
|
|
||
| class MockCellGeneratorWithSideEffect(ColumnGenerator[ExpressionColumnConfig]): | ||
| """Cell generator that produces a side-effect column alongside its primary output.""" | ||
|
|
||
| def __init__(self, *args: Any, side_effect_col: str, **kwargs: Any) -> None: | ||
| super().__init__(*args, **kwargs) | ||
| self._side_effect_col = side_effect_col | ||
|
|
||
| @staticmethod | ||
| def get_generation_strategy() -> GenerationStrategy: | ||
| return GenerationStrategy.CELL_BY_CELL | ||
|
|
||
| def generate(self, data: dict) -> dict: | ||
| data[self.config.name] = f"primary_{data.get('seed', '?')}" | ||
| data[self._side_effect_col] = f"side_effect_{data.get('seed', '?')}" | ||
| return data | ||
|
|
||
|
|
||
| @pytest.mark.asyncio(loop_scope="session") | ||
| async def test_scheduler_side_effect_columns_written_to_buffer() -> None: | ||
| """Side-effect columns (e.g. __reasoning_content) are persisted to the buffer. | ||
|
|
||
| Reproduces the bug where a generator produces extra columns (like trace or | ||
| reasoning_content) that aren't tracked in _instance_to_columns. Downstream | ||
| columns that reference these side-effect values must find them in the buffer. | ||
| """ | ||
| provider = _mock_provider() | ||
| side_effect_col = "answer__reasoning_content" | ||
|
|
||
| configs = [ | ||
| SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), | ||
| LLMTextColumnConfig( | ||
| name="answer", prompt="{{ seed }}", model_alias=MODEL_ALIAS, extract_reasoning_content=True | ||
| ), | ||
| LLMTextColumnConfig(name="judge", prompt=f"{{{{ {side_effect_col} }}}}", model_alias=MODEL_ALIAS), | ||
| ] | ||
| strategies = { | ||
| "seed": GenerationStrategy.FULL_COLUMN, | ||
| "answer": GenerationStrategy.CELL_BY_CELL, | ||
| "judge": GenerationStrategy.CELL_BY_CELL, | ||
| } | ||
|
|
||
| answer_gen = MockCellGeneratorWithSideEffect( | ||
| config=_expr_config("answer"), | ||
| resource_provider=provider, | ||
| side_effect_col=side_effect_col, | ||
| ) | ||
| generators: dict[str, ColumnGenerator] = { | ||
| "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), | ||
| "answer": answer_gen, | ||
| "judge": MockCellGenerator(config=_expr_config("judge"), resource_provider=provider), | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: the judge mock doesn't actually consume |
||
|
|
||
| graph = ExecutionGraph.create(configs, strategies) | ||
| row_groups = [(0, 3)] | ||
| tracker = CompletionTracker.with_graph(graph, row_groups) | ||
| buffer_manager = RowGroupBufferManager(MagicMock()) | ||
|
|
||
| scheduler = AsyncTaskScheduler( | ||
| generators=generators, | ||
| graph=graph, | ||
| tracker=tracker, | ||
| row_groups=row_groups, | ||
| buffer_manager=buffer_manager, | ||
| side_effect_map={side_effect_col: "answer"}, | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: this test exercises the scheduler directly, which proves the mechanism works, but doesn't go through |
||
| await scheduler.run() | ||
|
|
||
| assert tracker.is_row_group_complete(0, 3, ["seed", "answer", "judge"]) | ||
| for ri in range(3): | ||
| row = buffer_manager.get_row(0, ri) | ||
| assert side_effect_col in row, f"Side-effect column missing from row {ri}" | ||
| assert row[side_effect_col].startswith("side_effect_") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor:
`side_effect_map` is positional here, but all the other optional config params are keyword-only (after `*`). Might want to move it after `*` for consistency?