Skip to content

Commit 4f4c461

Browse files
committed
address cr feedback
1 parent 9fbf971 commit 4f4c461

4 files changed

Lines changed: 155 additions & 13 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,6 @@ def create(
7474
for se_col in sub.side_effect_columns:
7575
graph.set_side_effect(se_col, name)
7676

77-
graph.set_required_columns(name, list(sub.required_columns))
7877
graph.set_propagate_skip(name, sub.propagate_skip)
7978
if sub.skip is not None:
8079
graph.set_skip_config(name, sub.skip)
@@ -90,6 +89,7 @@ def create(
9089

9190
for sub in sub_configs:
9291
name = sub.name
92+
resolved_required: list[str] = []
9393
for req in sub.required_columns:
9494
resolved = graph.resolve_side_effect(req)
9595
if resolved not in known_columns:
@@ -98,7 +98,10 @@ def create(
9898
)
9999
if resolved == name:
100100
continue
101+
if resolved not in resolved_required:
102+
resolved_required.append(resolved)
101103
graph.add_edge(upstream=resolved, downstream=name)
104+
graph.set_required_columns(name, resolved_required)
102105

103106
if sub.skip is not None:
104107
for skip_col in sub.skip.columns:
@@ -135,7 +138,7 @@ def set_side_effect(self, side_effect_col: str, producer: str) -> None:
135138
self._producer_to_side_effect_map.setdefault(producer, []).append(side_effect_col)
136139

137140
def set_required_columns(self, column: str, required: list[str]) -> None:
138-
"""Store the config-level ``required_columns`` for *column*."""
141+
"""Store producer-resolved ``required_columns`` for skip propagation."""
139142
self._required_columns[column] = required
140143

141144
def set_propagate_skip(self, column: str, propagate: bool) -> None:
@@ -165,7 +168,7 @@ def get_downstream_columns(self, column: str) -> set[str]:
165168
return set(self._downstream.get(column, set()))
166169

167170
def get_required_columns(self, column: str) -> list[str]:
168-
"""Config-level ``required_columns`` for *column* (data dependencies only)."""
171+
"""Producer-resolved ``required_columns`` for *column* (data dependencies only)."""
169172
return list(self._required_columns.get(column, []))
170173

171174
def get_skip_config(self, column: str) -> SkipConfig | None:

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 85 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111

1212
import data_designer.lazy_heavy_imports as lazy
13+
from data_designer.config.base import SkipConfig
1314
from data_designer.config.column_configs import (
1415
CustomColumnConfig,
1516
ExpressionColumnConfig,
@@ -1563,8 +1564,6 @@ async def test_scheduler_skip_cell_by_cell_with_propagation() -> None:
15631564
Pipeline: seed(sampler) -> review(cell, skip.when seed<2) -> complaint(cell, propagate_skip)
15641565
Rows with seed < 2 should be skipped for review and propagated to complaint.
15651566
"""
1566-
from data_designer.config.base import SkipConfig
1567-
15681567
provider = _mock_provider()
15691568
num_records = 4
15701569

@@ -1640,15 +1639,97 @@ def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame:
16401639
assert row.get("complaint") is not None, f"row {ri}: complaint should be generated (seed={seed_val})"
16411640

16421641

1642+
@pytest.mark.asyncio(loop_scope="session")
1643+
async def test_scheduler_skip_propagates_through_side_effect_dependency() -> None:
1644+
"""A downstream dependency on a skipped side-effect should auto-skip.
1645+
1646+
Pipeline: seed(sampler) -> review(cell, skip.when seed<2, produces
1647+
review__trace) -> complaint(cell, depends on review__trace,
1648+
propagate_skip=True).
1649+
"""
1650+
provider = _mock_provider()
1651+
num_records = 4
1652+
1653+
configs = [
1654+
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
1655+
LLMTextColumnConfig(
1656+
name="review",
1657+
prompt="{{ seed }}",
1658+
model_alias=MODEL_ALIAS,
1659+
with_trace="last_message",
1660+
skip=SkipConfig(when="{{ seed < 2 }}"),
1661+
),
1662+
LLMTextColumnConfig(
1663+
name="complaint",
1664+
prompt="{{ review__trace }}",
1665+
model_alias=MODEL_ALIAS,
1666+
propagate_skip=True,
1667+
),
1668+
]
1669+
strategies = {
1670+
"seed": GenerationStrategy.FULL_COLUMN,
1671+
"review": GenerationStrategy.CELL_BY_CELL,
1672+
"complaint": GenerationStrategy.CELL_BY_CELL,
1673+
}
1674+
1675+
class IntSeedGenerator(FromScratchColumnGenerator[ExpressionColumnConfig]):
1676+
@staticmethod
1677+
def get_generation_strategy() -> GenerationStrategy:
1678+
return GenerationStrategy.FULL_COLUMN
1679+
1680+
def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame:
1681+
return data
1682+
1683+
def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame:
1684+
return lazy.pd.DataFrame({"seed": list(range(num_records))})
1685+
1686+
generators: dict[str, ColumnGenerator] = {
1687+
"seed": IntSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
1688+
"review": MockCellGenerator(config=_expr_config("review"), resource_provider=provider),
1689+
"complaint": MockCellGenerator(config=_expr_config("complaint"), resource_provider=provider),
1690+
}
1691+
1692+
storage = MagicMock()
1693+
storage.dataset_name = "test"
1694+
storage.get_file_paths.return_value = {}
1695+
buffer_mgr = RowGroupBufferManager(storage)
1696+
1697+
graph = ExecutionGraph.create(configs, strategies)
1698+
row_groups = [(0, num_records)]
1699+
tracker = CompletionTracker.with_graph(graph, row_groups)
1700+
1701+
scheduler = AsyncTaskScheduler(
1702+
generators=generators,
1703+
graph=graph,
1704+
tracker=tracker,
1705+
row_groups=row_groups,
1706+
buffer_manager=buffer_mgr,
1707+
trace=True,
1708+
num_records=num_records,
1709+
buffer_size=num_records,
1710+
)
1711+
await asyncio.wait_for(scheduler.run(), timeout=10.0)
1712+
1713+
assert tracker.is_row_group_complete(0, num_records, ["seed", "review", "complaint"])
1714+
1715+
for ri in range(num_records):
1716+
row = buffer_mgr.get_row(0, ri)
1717+
seed_val = row["seed"]
1718+
if seed_val < 2:
1719+
assert row.get("review") is None, f"row {ri}: review should be skipped (seed={seed_val})"
1720+
assert row.get("review__trace") is None, f"row {ri}: review__trace should be cleared on skip"
1721+
assert row.get("complaint") is None, f"row {ri}: complaint should propagate skip (seed={seed_val})"
1722+
else:
1723+
assert row.get("complaint") is not None, f"row {ri}: complaint should be generated (seed={seed_val})"
1724+
1725+
16431726
@pytest.mark.asyncio(loop_scope="session")
16441727
async def test_scheduler_skip_full_column_batch() -> None:
16451728
"""Full-column (batch) generator skips rows via expression gate.
16461729
16471730
Pipeline: seed(sampler) -> review(full_column, skip.when seed<2)
16481731
Only active (non-skipped) rows should be passed to the generator.
16491732
"""
1650-
from data_designer.config.base import SkipConfig
1651-
16521733
provider = _mock_provider()
16531734
num_records = 4
16541735

packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py

Lines changed: 63 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111

1212
import data_designer.engine.dataset_builders.dataset_builder as builder_mod
1313
import data_designer.lazy_heavy_imports as lazy
14+
from data_designer.config.base import SkipConfig
1415
from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig, SamplerColumnConfig
1516
from data_designer.config.config_builder import DataDesignerConfigBuilder
1617
from data_designer.config.custom_column import custom_column_generator
@@ -962,6 +963,21 @@ def fn(df: pd.DataFrame) -> pd.DataFrame:
962963
return fn
963964

964965

966+
def _make_label_generator_with_side_effect(label: str, side_effect_label: str, *required: str):
967+
"""FULL_COLUMN generator that adds a column plus one side-effect column."""
968+
969+
@custom_column_generator(required_columns=list(required), side_effect_columns=[side_effect_label])
970+
def fn(df: pd.DataFrame) -> pd.DataFrame:
971+
return df.assign(
972+
**{
973+
label: f"generated_{label}",
974+
side_effect_label: f"generated_{side_effect_label}",
975+
}
976+
)
977+
978+
return fn
979+
980+
965981
def test_skip_metadata_preserved_across_non_skip_aware_full_column(
966982
stub_resource_provider, stub_model_configs, seed_data_setup
967983
):
@@ -972,8 +988,6 @@ def test_skip_metadata_preserved_across_non_skip_aware_full_column(
972988
Before the fix, summary's replace_buffer erased __internal_skipped_columns,
973989
causing complaint to generate for rows that should have been skipped.
974990
"""
975-
from data_designer.config.base import SkipConfig
976-
977991
config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
978992
config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"])))
979993

@@ -1031,8 +1045,6 @@ def test_skip_metadata_preserved_when_no_rows_skipped_for_current_column(
10311045
own expression (it has none). The has_skipped=False fallthrough must still
10321046
preserve review's skip metadata so propagation works.
10331047
"""
1034-
from data_designer.config.base import SkipConfig
1035-
10361048
config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
10371049
config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"])))
10381050

@@ -1069,6 +1081,53 @@ def test_skip_metadata_preserved_when_no_rows_skipped_for_current_column(
10691081
assert row["analysis"] == "generated_analysis", f"seed_id={row['seed_id']}: analysis should be generated"
10701082

10711083

1084+
def test_skip_propagation_resolves_side_effect_dependencies_in_sync_builder(
1085+
stub_resource_provider, stub_model_configs, seed_data_setup
1086+
):
1087+
"""A downstream dependency on a skipped side-effect should auto-skip.
1088+
1089+
Scenario: review(skip.when, produces review_side_effect) ->
1090+
analysis(required_columns=[review_side_effect], propagate_skip=True).
1091+
"""
1092+
config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
1093+
config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"])))
1094+
1095+
config_builder.add_column(
1096+
CustomColumnConfig(
1097+
name="review",
1098+
generator_function=_make_label_generator_with_side_effect("review", "review_side_effect", "seed_id"),
1099+
generation_strategy=GenerationStrategy.FULL_COLUMN,
1100+
skip=SkipConfig(when="{{ seed_id < 3 }}"),
1101+
)
1102+
)
1103+
config_builder.add_column(
1104+
CustomColumnConfig(
1105+
name="analysis",
1106+
generator_function=_make_label_generator("analysis", "review_side_effect"),
1107+
generation_strategy=GenerationStrategy.FULL_COLUMN,
1108+
propagate_skip=True,
1109+
)
1110+
)
1111+
1112+
builder = DatasetBuilder(
1113+
data_designer_config=config_builder.build(),
1114+
resource_provider=stub_resource_provider,
1115+
)
1116+
result = builder.build_preview(num_records=5)
1117+
1118+
skipped_ids = {1, 2}
1119+
for _, row in result.iterrows():
1120+
if row["seed_id"] in skipped_ids:
1121+
assert row["review_side_effect"] is None or lazy.pd.isna(row["review_side_effect"]), (
1122+
f"seed_id={row['seed_id']}: review_side_effect should be cleared when review is skipped"
1123+
)
1124+
assert row["analysis"] is None or lazy.pd.isna(row["analysis"]), (
1125+
f"seed_id={row['seed_id']}: analysis should propagate skip from review"
1126+
)
1127+
else:
1128+
assert row["analysis"] == "generated_analysis", f"seed_id={row['seed_id']}: analysis should be generated"
1129+
1130+
10721131
def test_allow_resize_column_not_blocked_by_upstream_skip(stub_resource_provider, stub_model_configs, seed_data_setup):
10731132
"""An allow_resize=True column depending on a skippable upstream must not
10741133
enter the skip-aware branch (which enforces 1:1 row counts).
@@ -1077,8 +1136,6 @@ def test_allow_resize_column_not_blocked_by_upstream_skip(stub_resource_provider
10771136
with propagate_skip=True and required_columns pointing to a skippable
10781137
upstream, causing a DatasetGenerationError on the row-count check.
10791138
"""
1080-
from data_designer.config.base import SkipConfig
1081-
10821139
config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
10831140
config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"])))
10841141

packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,7 @@ def test_side_effect_column_resolution() -> None:
112112

113113
assert graph.get_upstream_columns("trace_len") == {"summary"}
114114
assert graph.get_downstream_columns("summary") == {"trace_len"}
115+
assert graph.get_required_columns("trace_len") == ["summary"]
115116

116117

117118
def test_reasoning_content_side_effect() -> None:

0 commit comments

Comments
 (0)