Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
tracker: CompletionTracker,
row_groups: list[tuple[int, int]],
buffer_manager: RowGroupBufferManager | None = None,
side_effect_map: dict[str, str] | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: `side_effect_map` is positional here, but all the other optional config params are keyword-only (after `*`). Might want to move it after `*` for consistency?

*,
max_concurrent_row_groups: int = 3,
max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE,
Expand Down Expand Up @@ -137,6 +138,14 @@ def __init__(
instance_to_columns.setdefault(id(gen), []).append(col)
self._instance_to_columns = instance_to_columns

# Extend with side-effect columns for buffer write-back only.
# _instance_to_columns stays unchanged for completion tracking / dispatch dedup.
write_cols: dict[int, list[str]] = {k: list(v) for k, v in instance_to_columns.items()}
for se_col, primary_col in (side_effect_map or {}).items():
gen = generators[primary_col]
write_cols.setdefault(id(gen), []).append(se_col)
self._instance_to_write_columns = write_cols

# Stateful generator tracking: instance_id → asyncio.Lock
self._stateful_locks: dict[int, asyncio.Lock] = {}
for col, gen in generators.items():
Expand Down Expand Up @@ -767,7 +776,7 @@ async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any

# Write results to buffer
if self._buffer_manager is not None:
output_cols = self._instance_to_columns.get(id(generator), [task.column])
output_cols = self._instance_to_write_columns.get(id(generator), [task.column])
for col in output_cols:
if col in result_df.columns:
values = result_df[col].tolist()
Expand All @@ -793,7 +802,7 @@ async def _run_cell(self, task: Task, generator: ColumnGenerator) -> Any:

# Write back to buffer
if self._buffer_manager is not None and not self._tracker.is_dropped(task.row_group, task.row_index):
output_cols = self._instance_to_columns.get(id(generator), [task.column])
output_cols = self._instance_to_write_columns.get(id(generator), [task.column])
for col in output_cols:
if col in result:
self._buffer_manager.update_cell(task.row_group, task.row_index, col, result[col])
Expand All @@ -817,7 +826,7 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:

# Merge result columns back to buffer
if self._buffer_manager is not None:
output_cols = self._instance_to_columns.get(id(generator), [task.column])
output_cols = self._instance_to_write_columns.get(id(generator), [task.column])
active_rows = rg_size - len(pre_dropped)
if len(result_df) != active_rows:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,19 @@ def _prepare_async_run(
"""
strategies: dict[str, GenerationStrategy] = {}
gen_map: dict[str, ColumnGenerator] = {}
side_effect_map: dict[str, str] = {}
for gen in generators:
if isinstance(gen.config, MultiColumnConfig):
for sub in gen.config.columns:
strategies[sub.name] = gen.get_generation_strategy()
gen_map[sub.name] = gen
for se_col in sub.side_effect_columns:
side_effect_map[se_col] = sub.name
else:
strategies[gen.config.name] = gen.get_generation_strategy()
gen_map[gen.config.name] = gen
for se_col in gen.config.side_effect_columns:
side_effect_map[se_col] = gen.config.name
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`ExecutionGraph.create()` on line 342 already builds the same `{side_effect_col: primary_col}` mapping internally (`_side_effect_map`). It might be worth exposing that as a read-only property on the graph and reading it here instead of building it independently. That keeps the graph as the single source of truth and avoids the two copies drifting apart if the side-effect logic gets more complex later.


graph = ExecutionGraph.create(self._column_configs, strategies)

Expand Down Expand Up @@ -379,6 +384,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
tracker=tracker,
row_groups=row_groups,
buffer_manager=buffer_manager,
side_effect_map=side_effect_map,
max_submitted_tasks=DEFAULT_TASK_POOL_SIZE,
max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, LLM_WAIT_POOL_MULTIPLIER * aggregate),
on_finalize_row_group=on_finalize_row_group,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1551,3 +1551,77 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None:
f"First judge dispatched at {first_judge_dispatched:.4f}, "
f"last gen dispatched at {last_gen_dispatched:.4f}."
)


class MockCellGeneratorWithSideEffect(ColumnGenerator[ExpressionColumnConfig]):
    """Mock cell-by-cell generator that emits one extra side-effect column.

    Besides the primary column (named after ``self.config.name``), each call
    to :meth:`generate` also writes a value under the side-effect column name
    supplied at construction time.
    """

    def __init__(self, *args: Any, side_effect_col: str, **kwargs: Any) -> None:
        """Forward all arguments to the base class and store the extra column name.

        Args:
            *args: Positional arguments passed through to the base constructor.
            side_effect_col: Name of the additional column written by ``generate``.
            **kwargs: Keyword arguments passed through to the base constructor.
        """
        super().__init__(*args, **kwargs)
        self._side_effect_col = side_effect_col

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        """Return the cell-by-cell generation strategy."""
        return GenerationStrategy.CELL_BY_CELL

    def generate(self, data: dict) -> dict:
        """Write the primary and side-effect values into ``data`` and return it.

        The input mapping is mutated in place. Both values embed the row's
        ``seed`` entry, falling back to ``'?'`` when it is absent.
        """
        row = data
        # The seed is looked up separately for each write, in this order, so
        # the second lookup observes whatever the first write stored.
        row[self.config.name] = f"primary_{row.get('seed', '?')}"
        extra_col = self._side_effect_col
        row[extra_col] = f"side_effect_{row.get('seed', '?')}"
        return row


@pytest.mark.asyncio(loop_scope="session")
async def test_scheduler_side_effect_columns_written_to_buffer() -> None:
"""Side-effect columns (e.g. __reasoning_content) are persisted to the buffer.

Reproduces the bug where a generator produces extra columns (like trace or
reasoning_content) that aren't tracked in _instance_to_columns. Downstream
columns that reference these side-effect values must find them in the buffer.
"""
provider = _mock_provider()
side_effect_col = "answer__reasoning_content"

configs = [
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
LLMTextColumnConfig(
name="answer", prompt="{{ seed }}", model_alias=MODEL_ALIAS, extract_reasoning_content=True
),
LLMTextColumnConfig(name="judge", prompt=f"{{{{ {side_effect_col} }}}}", model_alias=MODEL_ALIAS),
]
strategies = {
"seed": GenerationStrategy.FULL_COLUMN,
"answer": GenerationStrategy.CELL_BY_CELL,
"judge": GenerationStrategy.CELL_BY_CELL,
}

answer_gen = MockCellGeneratorWithSideEffect(
config=_expr_config("answer"),
resource_provider=provider,
side_effect_col=side_effect_col,
)
generators: dict[str, ColumnGenerator] = {
"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
"answer": answer_gen,
"judge": MockCellGenerator(config=_expr_config("judge"), resource_provider=provider),
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the judge mock doesn't actually consume `answer__reasoning_content` from its input data, so this doesn't fully prove that downstream columns can read the side-effect value from the buffer. Maybe worth swapping in a mock that asserts `data[side_effect_col]` is present? The current test still catches the original bug, though, so this is not a blocker.


graph = ExecutionGraph.create(configs, strategies)
row_groups = [(0, 3)]
tracker = CompletionTracker.with_graph(graph, row_groups)
buffer_manager = RowGroupBufferManager(MagicMock())

scheduler = AsyncTaskScheduler(
generators=generators,
graph=graph,
tracker=tracker,
row_groups=row_groups,
buffer_manager=buffer_manager,
side_effect_map={side_effect_col: "answer"},
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this test exercises the scheduler directly, which proves the mechanism works, but it doesn't go through `_prepare_async_run()`, where the production wiring happens. If someone refactors the builder and forgets to pass `side_effect_map`, this test would still pass. Maybe worth an integration-level test through `build_preview()` too?

await scheduler.run()

assert tracker.is_row_group_complete(0, 3, ["seed", "answer", "judge"])
for ri in range(3):
row = buffer_manager.get_row(0, ri)
assert side_effect_col in row, f"Side-effect column missing from row {ri}"
assert row[side_effect_col].startswith("side_effect_")
Loading