@@ -131,11 +131,24 @@ def __init__(
self._disable_early_shutdown = disable_early_shutdown
self._early_shutdown = False

# Multi-column dedup: group output columns by generator identity
instance_to_columns: dict[int, list[str]] = {}
# Multi-column dedup: group output columns by generator identity.
# _gen_instance_to_columns holds only real (graph-registered) columns
# and is used for completion tracking.
# _gen_instance_to_columns_including_side_effects extends that with
# side-effect columns for buffer writes only.
gen_instance_to_columns: dict[int, list[str]] = {}
for col, gen in generators.items():
instance_to_columns.setdefault(id(gen), []).append(col)
self._instance_to_columns = instance_to_columns
gen_instance_to_columns.setdefault(id(gen), []).append(col)
self._gen_instance_to_columns = gen_instance_to_columns

seen_cols: set[str] = set(generators)
gen_instance_to_columns_incl_se: dict[int, list[str]] = {k: list(v) for k, v in gen_instance_to_columns.items()}
for col, gen in generators.items():
for side_effect_col in getattr(gen.config, "side_effect_columns", []):
if side_effect_col not in seen_cols:
gen_instance_to_columns_incl_se.setdefault(id(gen), []).append(side_effect_col)
seen_cols.add(side_effect_col)
self._gen_instance_to_columns_including_side_effects = gen_instance_to_columns_incl_se

# Stateful generator tracking: instance_id → asyncio.Lock
self._stateful_locks: dict[int, asyncio.Lock] = {}
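
A self-contained sketch of the two-dict split above; _FakeGen and _FakeConfig are stand-ins for real generator objects (only the .config.side_effect_columns attribute matters here). One generator instance registered under two column names gets a single entry keyed by id(), and side-effect columns are appended only to the buffer-write dict:

class _FakeConfig:
    def __init__(self, side_effect_columns: list[str]) -> None:
        self.side_effect_columns = side_effect_columns

class _FakeGen:
    def __init__(self, side_effect_columns: tuple[str, ...] = ()) -> None:
        self.config = _FakeConfig(list(side_effect_columns))

shared = _FakeGen(side_effect_columns=("audit_log",))  # one instance, two output columns
generators = {"a": shared, "b": shared, "c": _FakeGen()}

by_instance: dict[int, list[str]] = {}
for col, gen in generators.items():
    by_instance.setdefault(id(gen), []).append(col)

including_se = {k: list(v) for k, v in by_instance.items()}
seen = set(generators)
for col, gen in generators.items():
    for se in gen.config.side_effect_columns:
        if se not in seen:  # real columns shadow same-named side effects
            including_se.setdefault(id(gen), []).append(se)
            seen.add(se)

assert by_instance[id(shared)] == ["a", "b"]                # completion tracking
assert including_se[id(shared)] == ["a", "b", "audit_log"]  # buffer writes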
@@ -356,7 +369,7 @@ async def _salvage_rounds(
self._dispatched.discard(
Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch")
)
for sibling in self._instance_to_columns.get(gid, []):
for sibling in self._gen_instance_to_columns.get(gid, []):
if sibling != task.column:
self._dispatched.discard(
Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch")
@@ -377,7 +390,7 @@ async def _salvage_rounds(
)
# Re-mark sibling columns as dispatched to mirror _dispatch_seeds
# and prevent _drain_frontier from re-dispatching them.
for sibling in self._instance_to_columns.get(gid, []):
for sibling in self._gen_instance_to_columns.get(gid, []):
if sibling != task.column:
self._dispatched.add(
Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch")
@@ -620,7 +633,7 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None:
self._dispatched.add(task)
self._dispatched.add(batch_alias)
# Also mark all sibling output columns as dispatched (multi-column dedup)
for sibling_col in self._instance_to_columns.get(gid, []):
for sibling_col in self._gen_instance_to_columns.get(gid, []):
if sibling_col != col:
self._dispatched.add(
Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="from_scratch")
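
The dedup above relies on Task being hashable, so that set membership in _dispatched deduplicates sibling columns. A minimal sketch, assuming a frozen-dataclass Task (the real definition lives in task_model and may differ):

from dataclasses import dataclass

@dataclass(frozen=True)
class Task:
    column: str
    row_group: int
    row_index: int | None
    task_type: str

dispatched: set[Task] = set()
# Dispatching "a" also marks its sibling "b" (same generator instance), so
# the frontier finds "b" already present and skips re-dispatching it:
dispatched.add(Task("a", 0, None, "from_scratch"))
dispatched.add(Task("b", 0, None, "from_scratch"))
assert Task("b", 0, None, "from_scratch") in dispatched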
@@ -665,7 +678,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None:
trace.dispatched_at = time.perf_counter()

generator = self._generators[task.column]
output_cols = self._instance_to_columns.get(id(generator), [task.column])
output_cols = self._gen_instance_to_columns.get(id(generator), [task.column])
retryable = False
# When True, skip removing from _dispatched so the task isn't re-dispatched
# from the frontier (it was never completed, so it stays in the frontier).
@@ -765,10 +778,10 @@ async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any
else:
result_df = await generator.agenerate(lazy.pd.DataFrame())

# Write results to buffer
# Write results to buffer (include side-effect columns)
if self._buffer_manager is not None:
output_cols = self._instance_to_columns.get(id(generator), [task.column])
for col in output_cols:
write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
for col in write_cols:
if col in result_df.columns:
values = result_df[col].tolist()
self._buffer_manager.update_batch(task.row_group, col, values)
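
A sketch of the write-back filter with a stand-in buffer (not the real BufferManager API beyond the update_batch call shown above): a side-effect column that is declared but absent from the generator's result is skipped rather than written as nulls.

import pandas as pd

class _FakeBuffer:
    def __init__(self) -> None:
        self.writes: dict[tuple[int, str], list] = {}

    def update_batch(self, row_group: int, col: str, values: list) -> None:
        self.writes[(row_group, col)] = values

buffer = _FakeBuffer()
result_df = pd.DataFrame({"primary": [1, 2], "side_a": ["x", "y"]})
write_cols = ["primary", "side_a", "side_b"]  # "side_b" declared but not produced
for col in write_cols:
    if col in result_df.columns:
        buffer.update_batch(0, col, result_df[col].tolist())

assert (0, "side_b") not in buffer.writes  # absent columns are silently skipped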
@@ -791,10 +804,10 @@ async def _run_cell(self, task: Task, generator: ColumnGenerator) -> Any:

result = await generator.agenerate(row_data)

# Write back to buffer
# Write back to buffer (include side-effect columns)
if self._buffer_manager is not None and not self._tracker.is_dropped(task.row_group, task.row_index):
output_cols = self._instance_to_columns.get(id(generator), [task.column])
for col in output_cols:
write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
for col in write_cols:
if col in result:
self._buffer_manager.update_cell(task.row_group, task.row_index, col, result[col])

@@ -815,9 +828,9 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:

result_df = await generator.agenerate(batch_df)

# Merge result columns back to buffer
# Merge result columns back to buffer (include side-effect columns)
if self._buffer_manager is not None:
output_cols = self._instance_to_columns.get(id(generator), [task.column])
write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
active_rows = rg_size - len(pre_dropped)
if len(result_df) != active_rows:
raise ValueError(
@@ -830,7 +843,7 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:
continue
# Skip writing to rows dropped by concurrent tasks during the await
if not self._buffer_manager.is_dropped(task.row_group, ri):
for col in output_cols:
for col in write_cols:
if col in result_df.columns:
self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col])
result_idx += 1
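
The merge walks row indices of the row group while result_idx walks the compacted result_df. A standalone sketch of that alignment, assuming (as the hunk suggests) that rows dropped before dispatch consume no result row, while rows dropped concurrently during the await consume one but are not written back:

import pandas as pd

rg_size = 5
pre_dropped = {1}           # dropped before dispatch: absent from result_df
dropped_during_await = {3}  # dropped while the generator was running
result_df = pd.DataFrame({"col": ["r0", "r2", "r3", "r4"]})  # rg_size - len(pre_dropped) rows

written: dict[int, str] = {}
result_idx = 0
for ri in range(rg_size):
    if ri in pre_dropped:
        continue                        # no result row to consume
    if ri not in dropped_during_await:  # stand-in for buffer_manager.is_dropped(...)
        written[ri] = result_df.iloc[result_idx]["col"]
    result_idx += 1                     # consume the result row either way

assert written == {0: "r0", 2: "r2", 4: "r4"}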
@@ -8,7 +8,7 @@
import data_designer.lazy_heavy_imports as lazy
from data_designer.config.column_types import ColumnConfigT
from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError
from data_designer.logging import LOG_INDENT

logger = logging.getLogger(__name__)
@@ -29,6 +29,18 @@ def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> li

side_effect_dict = {n: list(c.side_effect_columns) for n, c in dag_column_config_dict.items()}

side_effect_to_producer: dict[str, str] = {}
for producer, cols in side_effect_dict.items():
for col in cols:
existing = side_effect_to_producer.get(col)
if existing is not None and existing != producer:
raise ConfigCompilationError(
f"Side-effect column {col!r} is already produced by {existing!r}; "
f"cannot register a second producer {producer!r}. "
f"Use distinct side-effect column names for each pipeline stage."
)
side_effect_to_producer[col] = producer

logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph")
for name, col in dag_column_config_dict.items():
dag.add_node(name)
@@ -3,6 +3,7 @@

from __future__ import annotations

import logging
import math
from collections import deque

@@ -11,9 +12,11 @@
DatasetBuilderColumnConfigT,
MultiColumnConfig,
)
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.task_model import SliceRef

logger = logging.getLogger(__name__)


class ExecutionGraph:
"""Column-level static execution graph built from column configs.
@@ -105,7 +108,19 @@ def add_edge(self, upstream: str, downstream: str) -> None:
self._downstream.setdefault(upstream, set()).add(downstream)

def set_side_effect(self, side_effect_col: str, producer: str) -> None:
"""Map a side-effect column name to its producing column."""
"""Map a side-effect column name to its producing column.

Each side-effect column must have exactly one producer. Duplicate
registrations from a different producer are a configuration error;
use distinct column names for each pipeline stage instead.
"""
existing = self._side_effect_map.get(side_effect_col)
if existing is not None and existing != producer:
raise ConfigCompilationError(
f"Side-effect column {side_effect_col!r} is already produced by {existing!r}; "
f"cannot register a second producer {producer!r}. "
f"Use distinct side-effect column names for each pipeline stage."
)
self._side_effect_map[side_effect_col] = producer

def resolve_side_effect(self, column: str) -> str:
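
A brief usage sketch of the invariant enforced above (imports as in the tests below; that resolve_side_effect returns the registered producer is an assumption based on the mapping set here): re-registering the same producer is idempotent, while a different producer raises ConfigCompilationError.

graph = ExecutionGraph()
graph.add_column("clean_text", GenerationStrategy.CELL_BY_CELL)
graph.set_side_effect("token_count", producer="clean_text")
graph.set_side_effect("token_count", producer="clean_text")  # same producer: no error
assert graph.resolve_side_effect("token_count") == "clean_text"  # assumed behavior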
@@ -1432,6 +1432,51 @@ async def test_scheduler_rg_semaphore_deadlock_with_transient_failures() -> None
assert tracker.is_row_group_complete(1, 2, ["seed", "col"])


def test_side_effect_columns_separated_from_completion_tracking() -> None:
"""Side-effect columns must appear in _gen_instance_to_columns_including_side_effects
(buffer writes) but NOT in _gen_instance_to_columns (completion tracking), because
they are not registered in the execution graph and would cause KeyError in
CompletionTracker.
"""
graph = ExecutionGraph()
graph.add_column("seed", GenerationStrategy.FULL_COLUMN)
graph.add_column("primary", GenerationStrategy.CELL_BY_CELL)
graph.add_edge(upstream="seed", downstream="primary")

row_groups = [(0, 2)]
tracker = CompletionTracker.with_graph(graph, row_groups)

provider = _mock_provider()
seed_gen = MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider)
cell_gen = MockCellGenerator(config=_expr_config("primary"), resource_provider=provider)
# Replace the config with a mock that reports side-effect columns.
mock_config = MagicMock()
mock_config.side_effect_columns = ["side_a", "side_b"]
object.__setattr__(cell_gen, "_config", mock_config)

generators: dict[str, ColumnGenerator] = {"seed": seed_gen, "primary": cell_gen}

scheduler = AsyncTaskScheduler(
generators=generators,
graph=graph,
tracker=tracker,
row_groups=row_groups,
)

cell_id = id(cell_gen)

# Completion tracking dict: only real columns
assert "side_a" not in scheduler._gen_instance_to_columns.get(cell_id, [])
assert "side_b" not in scheduler._gen_instance_to_columns.get(cell_id, [])
assert "primary" in scheduler._gen_instance_to_columns.get(cell_id, [])

# Buffer write dict: includes side-effect columns
write_cols = scheduler._gen_instance_to_columns_including_side_effects.get(cell_id, [])
assert "primary" in write_cols
assert "side_a" in write_cols
assert "side_b" in write_cols


# -- TrackingSemaphore tests ---------------------------------------------------


@@ -1,9 +1,12 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest

from data_designer.config.column_configs import (
CustomColumnConfig,
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
@@ -13,12 +16,13 @@
ValidationColumnConfig,
)
from data_designer.config.column_types import DataDesignerColumnType
from data_designer.config.custom_column import custom_column_generator
from data_designer.config.sampler_params import SamplerType
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.validator_params import CodeValidatorParams
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError

MODEL_ALIAS = "stub-model-alias"

@@ -111,3 +115,23 @@ def test_circular_dependencies():
)
with pytest.raises(DAGCircularDependencyError, match="cyclic dependencies"):
topologically_sort_column_configs(column_configs)


def test_duplicate_side_effect_producers_raises() -> None:
"""Two custom columns declaring the same side-effect column is a configuration error."""

@custom_column_generator(required_columns=["text"], side_effect_columns=["shared_col"])
def gen_a(row: dict[str, Any]) -> dict[str, Any]:
return row

@custom_column_generator(required_columns=["text"], side_effect_columns=["shared_col"])
def gen_b(row: dict[str, Any]) -> dict[str, Any]:
return row

column_configs = [
LLMTextColumnConfig(name="text", prompt="hello", model_alias=MODEL_ALIAS),
CustomColumnConfig(name="col_a", generator_function=gen_a),
CustomColumnConfig(name="col_b", generator_function=gen_b),
]
with pytest.raises(ConfigCompilationError, match="already produced by"):
topologically_sort_column_configs(column_configs)
@@ -19,7 +19,7 @@
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.validator_params import CodeValidatorParams
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph
from data_designer.engine.dataset_builders.utils.task_model import SliceRef

@@ -156,6 +156,17 @@ def test_side_effect_name_collision_prefers_real_column() -> None:
assert graph.get_downstream_columns("summary") == set()


def test_side_effect_collision_raises() -> None:
"""Two producers for the same side-effect column is a configuration error."""
graph = ExecutionGraph()
graph.add_column("producer_a", GenerationStrategy.CELL_BY_CELL)
graph.add_column("producer_b", GenerationStrategy.CELL_BY_CELL)

graph.set_side_effect("shared_se", "producer_a")
with pytest.raises(ConfigCompilationError, match="already produced by 'producer_a'"):
graph.set_side_effect("shared_se", "producer_b")


# -- Validation tests -------------------------------------------------------

