Skip to content

Commit 4bc63b8

Browse files
committed
add skip conditional generation edge case tests
- test_skip_evaluator: parametrized should_skip_column_for_record tests covering propagation, expression gates, short-circuiting, and disabled propagation
- test_execution_graph: skip metadata accessors (get_skip_config, should_propagate_skip, get_required_columns, get_side_effect_columns, resolve_side_effect) and skip.when DAG edges
- test_dataset_builder: chained transitive propagation (4 levels), two independent skip gates, custom skip.value, row count preservation

Made-with: Cursor
1 parent 8931b45 commit 4bc63b8

File tree

3 files changed

+381
-0
lines changed

3 files changed

+381
-0
lines changed

packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,3 +1211,188 @@ def test_allow_resize_column_not_blocked_by_upstream_skip(stub_resource_provider
12111211
)
12121212
result = builder.build_preview(num_records=5)
12131213
assert len(result) == 10
1214+
1215+
1216+
def test_skip_chained_transitive_propagation_through_three_levels(
    stub_resource_provider, stub_model_configs, seed_data_setup
) -> None:
    """A skip decided at L1 must cascade through the three downstream columns.

    Pipeline: seed_id(seed) -> L1(skip.when seed_id < 3) -> L2 -> L3 -> L4,
    where L2, L3 and L4 each set ``propagate_skip=True``. Rows with seed_id
    1 or 2 are expected to be null in all four columns; every other row is
    expected to hold ``generated_<column>``.
    """
    config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
    seed_source = LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))
    config_builder.with_seed_dataset(seed_source)

    # Head of the chain: the only column carrying an explicit skip expression.
    config_builder.add_column(
        CustomColumnConfig(
            name="L1",
            generator_function=_make_label_generator("L1", "seed_id"),
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            skip=SkipConfig(when="{{ seed_id < 3 }}"),
        )
    )
    # Each remaining column depends on its predecessor and propagates skips.
    for upstream, name in (("L1", "L2"), ("L2", "L3"), ("L3", "L4")):
        config_builder.add_column(
            CustomColumnConfig(
                name=name,
                generator_function=_make_label_generator(name, upstream),
                generation_strategy=GenerationStrategy.FULL_COLUMN,
                propagate_skip=True,
            )
        )

    preview = DatasetBuilder(
        data_designer_config=config_builder.build(),
        resource_provider=stub_resource_provider,
    ).build_preview(num_records=5)

    assert len(preview) == 5
    chain = ("L1", "L2", "L3", "L4")
    for _, row in preview.iterrows():
        expect_skipped = row["seed_id"] in {1, 2}
        for col in chain:
            if expect_skipped:
                assert row[col] is None or lazy.pd.isna(row[col]), (
                    f"seed_id={row['seed_id']}: {col} should be skipped transitively"
                )
            else:
                assert row[col] == f"generated_{col}", f"seed_id={row['seed_id']}: {col} should be generated"
1276+
1277+
1278+
def test_skip_two_independent_gates_in_same_pipeline(
    stub_resource_provider, stub_model_configs, seed_data_setup
) -> None:
    """A propagating column fed by two separately gated columns.

    Pipeline: seed_id(seed) -> gate_a(skip.when seed_id < 3) and
    gate_b(skip.when seed_id > 4) -> merge(propagate_skip=True).
    ``merge`` is expected to be null whenever either gate's expression holds
    for the row, and ``generated_merge`` otherwise.
    """
    config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
    seed_source = LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))
    config_builder.with_seed_dataset(seed_source)

    # Both gates read seed_id directly; only their skip expressions differ.
    gate_expressions = (
        ("gate_a", "{{ seed_id < 3 }}"),
        ("gate_b", "{{ seed_id > 4 }}"),
    )
    for gate_name, expression in gate_expressions:
        config_builder.add_column(
            CustomColumnConfig(
                name=gate_name,
                generator_function=_make_label_generator(gate_name, "seed_id"),
                generation_strategy=GenerationStrategy.FULL_COLUMN,
                skip=SkipConfig(when=expression),
            )
        )
    config_builder.add_column(
        CustomColumnConfig(
            name="merge",
            generator_function=_make_label_generator("merge", "gate_a", "gate_b"),
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            propagate_skip=True,
        )
    )

    preview = DatasetBuilder(
        data_designer_config=config_builder.build(),
        resource_provider=stub_resource_provider,
    ).build_preview(num_records=5)

    assert len(preview) == 5
    for _, row in preview.iterrows():
        sid = row["seed_id"]
        any_gate_skipped = sid < 3 or sid > 4
        if not any_gate_skipped:
            assert row["merge"] == "generated_merge", f"seed_id={sid}: merge should be generated"
            continue
        assert row["merge"] is None or lazy.pd.isna(row["merge"]), (
            f"seed_id={sid}: merge should be skipped (gate_a or gate_b skipped)"
        )
1329+
1330+
1331+
def test_skip_custom_value_preserved_in_output(stub_resource_provider, stub_model_configs, seed_data_setup) -> None:
    """Skipped cells must carry the configured ``skip.value`` in the result.

    ``review`` is skipped when seed_id < 3 with a string sentinel as its skip
    value; rows with seed_id 1 or 2 are expected to hold that sentinel and
    the remaining rows ``generated_review``.
    """
    sentinel = "__SKIPPED__"
    config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
    seed_source = LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))
    config_builder.with_seed_dataset(seed_source)

    review_column = CustomColumnConfig(
        name="review",
        generator_function=_make_label_generator("review", "seed_id"),
        generation_strategy=GenerationStrategy.FULL_COLUMN,
        skip=SkipConfig(when="{{ seed_id < 3 }}", value=sentinel),
    )
    config_builder.add_column(review_column)

    preview = DatasetBuilder(
        data_designer_config=config_builder.build(),
        resource_provider=stub_resource_provider,
    ).build_preview(num_records=5)

    assert len(preview) == 5
    for _, row in preview.iterrows():
        if row["seed_id"] not in {1, 2}:
            assert row["review"] == "generated_review", f"seed_id={row['seed_id']}: review should be generated"
            continue
        assert row["review"] == sentinel, f"seed_id={row['seed_id']}: review should have custom skip value"
1359+
1360+
1361+
def test_skip_row_count_preserved_across_pipeline(stub_resource_provider, stub_model_configs, seed_data_setup) -> None:
    """Skipping cells must leave the number and order of rows untouched.

    Pipeline: review(skip.when seed_id < 3) -> analysis -> summary, with the
    last two propagating skips. The preview is expected to still contain the
    five seed rows, with seed_id values 1 through 5 in order.
    """
    config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
    seed_source = LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))
    config_builder.with_seed_dataset(seed_source)

    config_builder.add_column(
        CustomColumnConfig(
            name="review",
            generator_function=_make_label_generator("review", "seed_id"),
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            skip=SkipConfig(when="{{ seed_id < 3 }}"),
        )
    )
    # Downstream columns inherit the skip from their single upstream column.
    for upstream, name in (("review", "analysis"), ("analysis", "summary")):
        config_builder.add_column(
            CustomColumnConfig(
                name=name,
                generator_function=_make_label_generator(name, upstream),
                generation_strategy=GenerationStrategy.FULL_COLUMN,
                propagate_skip=True,
            )
        )

    preview = DatasetBuilder(
        data_designer_config=config_builder.build(),
        resource_provider=stub_resource_provider,
    ).build_preview(num_records=5)

    assert len(preview) == 5, "Skip must not change the row count"
    assert preview["seed_id"].tolist() == [1, 2, 3, 4, 5]

packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pytest
77

8+
from data_designer.config.base import SkipConfig
89
from data_designer.config.column_configs import (
910
ExpressionColumnConfig,
1011
GenerationStrategy,
@@ -449,3 +450,104 @@ def test_judge_column_dependency() -> None:
449450
graph = ExecutionGraph.create(configs, strategies)
450451

451452
assert graph.get_upstream_columns("judge") == {"text"}
453+
454+
455+
# -- Skip metadata accessors ------------------------------------------------
456+
457+
458+
def _build_skip_pipeline_graph() -> ExecutionGraph:
    """Build the four-column graph shared by the skip metadata tests.

    Layout: gate(sampler) -> review(skip.when, with_trace) ->
    analysis(propagate_skip=True) -> summary(propagate_skip=False).
    """
    gate = SamplerColumnConfig(name="gate", sampler_type=SamplerType.CATEGORY, params={"values": [0, 1]})
    review = LLMTextColumnConfig(
        name="review",
        prompt="{{ gate }}",
        model_alias=MODEL_ALIAS,
        with_trace="last_message",
        skip=SkipConfig(when="{{ gate == 0 }}"),
    )
    analysis = LLMTextColumnConfig(
        name="analysis",
        prompt="{{ review }}",
        model_alias=MODEL_ALIAS,
        propagate_skip=True,
    )
    summary = LLMTextColumnConfig(
        name="summary",
        prompt="{{ analysis }}",
        model_alias=MODEL_ALIAS,
        propagate_skip=False,
    )

    # The sampler is generated as a full column; the LLM columns cell by cell.
    strategies = {"gate": GenerationStrategy.FULL_COLUMN}
    for llm_column in ("review", "analysis", "summary"):
        strategies[llm_column] = GenerationStrategy.CELL_BY_CELL

    return ExecutionGraph.create([gate, review, analysis, summary], strategies)
489+
490+
491+
def test_skip_config_returned_for_gated_column() -> None:
    """``get_skip_config`` returns the skip config of a column that has one."""
    review_skip = _build_skip_pipeline_graph().get_skip_config("review")
    assert review_skip is not None
    assert review_skip.when == "{{ gate == 0 }}"
496+
497+
498+
def test_skip_config_returns_none_for_ungated_column() -> None:
    """Columns configured without ``skip`` report no skip config."""
    graph = _build_skip_pipeline_graph()
    for ungated in ("gate", "analysis"):
        assert graph.get_skip_config(ungated) is None
502+
503+
504+
def test_should_propagate_skip_explicit_values() -> None:
    """An explicitly configured ``propagate_skip`` flag is reported as set."""
    graph = _build_skip_pipeline_graph()
    analysis_flag = graph.should_propagate_skip("analysis")
    assert analysis_flag is True
    summary_flag = graph.should_propagate_skip("summary")
    assert summary_flag is False
508+
509+
510+
def test_should_propagate_skip_defaults_true() -> None:
    """Columns that do not set ``propagate_skip`` are reported as propagating."""
    graph = _build_skip_pipeline_graph()
    for unset in ("gate", "review"):
        assert graph.should_propagate_skip(unset) is True
514+
515+
516+
def test_get_required_columns_for_skip_pipeline() -> None:
    """Each downstream column requires exactly its single upstream column."""
    graph = _build_skip_pipeline_graph()
    expected_requirements = (
        ("review", ["gate"]),
        ("analysis", ["review"]),
        ("summary", ["analysis"]),
    )
    for column, required in expected_requirements:
        assert graph.get_required_columns(column) == required
521+
522+
523+
def test_get_side_effect_columns_for_skip_pipeline() -> None:
    """Only the column configured with ``with_trace`` lists a side-effect column."""
    graph = _build_skip_pipeline_graph()
    review_side_effects = graph.get_side_effect_columns("review")
    assert review_side_effects == ["review__trace"]
    analysis_side_effects = graph.get_side_effect_columns("analysis")
    assert analysis_side_effects == []
527+
528+
529+
def test_side_effect_dependency_resolves_to_producer() -> None:
    """A side-effect column name resolves to the column that produces it."""
    producer = _build_skip_pipeline_graph().resolve_side_effect("review__trace")
    assert producer == "review"
532+
533+
534+
def test_skip_when_columns_create_dag_edges() -> None:
    """A column named only in ``skip.when`` must still be an upstream dependency.

    ``output`` references ``data`` in its prompt and ``gate`` only in its skip
    expression; both are expected among its upstream columns.
    """
    gate = SamplerColumnConfig(name="gate", sampler_type=SamplerType.CATEGORY, params={"values": [0, 1]})
    data = SamplerColumnConfig(name="data", sampler_type=SamplerType.CATEGORY, params={"values": ["x"]})
    output = LLMTextColumnConfig(
        name="output",
        prompt="{{ data }}",
        model_alias=MODEL_ALIAS,
        skip=SkipConfig(when="{{ gate == 0 }}"),
    )
    strategies = {
        "gate": GenerationStrategy.FULL_COLUMN,
        "data": GenerationStrategy.FULL_COLUMN,
        "output": GenerationStrategy.CELL_BY_CELL,
    }

    graph = ExecutionGraph.create([gate, data, output], strategies)

    for dependency in ("gate", "data"):
        assert dependency in graph.get_upstream_columns("output")

0 commit comments

Comments
 (0)