Skip to content

Commit 8931b45

Browse files
committed
Fix issue with full column generating messing up order of skipped rows
1 parent 4f4c461 commit 8931b45

File tree

4 files changed

+207
-28
lines changed

4 files changed

+207
-28
lines changed

packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from data_designer.engine.dataset_builders.utils.skip_tracker import (
4343
SKIPPED_COLUMNS_RECORD_KEY,
4444
apply_skip_to_record,
45+
prepare_records_for_skip_metadata_round_trip,
4546
restore_skip_metadata,
4647
strip_skip_metadata_from_records,
4748
)
@@ -579,11 +580,19 @@ def _run_full_column_generator_without_skip(self, generator: ColumnGenerator) ->
579580
original_count = self.batch_manager.num_records_in_buffer
580581
allow_resize = generator.config.allow_resize if not isinstance(generator.config, MultiColumnConfig) else False
581582
old_records = [record for _, record in self.batch_manager.iter_current_batch()]
583+
input_records, restore_context = prepare_records_for_skip_metadata_round_trip(old_records)
582584

583-
df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
585+
df = generator.generate(lazy.pd.DataFrame(input_records))
584586
self._log_resize_if_changed(self._column_display_name(generator.config), original_count, len(df), allow_resize)
585587
new_records = df.to_dict(orient="records")
586-
restore_skip_metadata(old_records, new_records)
588+
if restore_context is not None:
589+
try:
590+
restore_skip_metadata(new_records, context=restore_context, allow_resize=allow_resize)
591+
except ValueError as exc:
592+
raise DatasetGenerationError(
593+
f"Unable to restore skip provenance after FULL_COLUMN generation for "
594+
f"{self._column_display_name(generator.config)}: {exc}"
595+
) from exc
587596
self.batch_manager.replace_buffer(new_records, allow_resize=allow_resize)
588597

589598
def _run_full_column_generator_with_skip(self, generator: ColumnGenerator, column_name: str) -> None:

packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,20 @@
1010
from __future__ import annotations
1111

1212
from collections.abc import Sequence
13+
from dataclasses import dataclass
1314
from typing import Final
1415

1516
SKIPPED_COLUMNS_RECORD_KEY: Final[str] = "__internal_skipped_columns"
17+
SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX: Final[str] = "__internal_skip_restore_id"
18+
19+
20+
@dataclass(frozen=True, slots=True)
21+
class SkipMetadataRestoreContext:
22+
"""Metadata needed to restore skip provenance after a DataFrame round-trip."""
23+
24+
restore_id_column: str
25+
source_ids: set[str]
26+
skipped_columns_by_source_id: dict[str, set[str]]
1627

1728

1829
def apply_skip_to_record(
@@ -46,16 +57,82 @@ def strip_skip_metadata_from_records(records: Sequence[dict]) -> list[dict]:
4657
return [strip_skip_metadata_for_dataframe_row(r) for r in records]
4758

4859

49-
def restore_skip_metadata(old_records: Sequence[dict], new_records: Sequence[dict]) -> None:
50-
"""Copy ``SKIPPED_COLUMNS_RECORD_KEY`` from *old_records* into *new_records* in-place.
60+
def prepare_records_for_skip_metadata_round_trip(
61+
records: Sequence[dict],
62+
) -> tuple[list[dict], SkipMetadataRestoreContext | None]:
63+
"""Prepare records for a DataFrame round-trip while preserving skip metadata.
5164
52-
``pd.DataFrame`` construction drops non-column keys, so skip metadata is
53-
lost when records round-trip through a DataFrame. Call this after
54-
``df.to_dict(orient="records")`` to re-attach the metadata before passing
55-
the records to ``replace_buffer``. When lengths differ (e.g.
56-
``allow_resize``), only positionally matched rows are restored.
65+
Returns stripped records ready for ``pd.DataFrame(...)``. If any record has
66+
skip metadata, injects a hidden restore-ID column and returns a context that
67+
can later be passed to :func:`restore_skip_metadata`.
5768
"""
58-
for i in range(min(len(old_records), len(new_records))):
59-
meta = old_records[i].get(SKIPPED_COLUMNS_RECORD_KEY)
69+
if not any(SKIPPED_COLUMNS_RECORD_KEY in record for record in records):
70+
return strip_skip_metadata_from_records(records), None
71+
72+
restore_id_column = _choose_restore_id_column(records)
73+
prepared_records: list[dict] = []
74+
source_ids: set[str] = set()
75+
skipped_columns_by_source_id: dict[str, set[str]] = {}
76+
77+
for index, record in enumerate(records):
78+
source_id = str(index)
79+
source_ids.add(source_id)
80+
prepared_record = strip_skip_metadata_for_dataframe_row(record)
81+
prepared_record[restore_id_column] = source_id
82+
prepared_records.append(prepared_record)
83+
84+
meta = record.get(SKIPPED_COLUMNS_RECORD_KEY)
6085
if meta is not None:
61-
new_records[i][SKIPPED_COLUMNS_RECORD_KEY] = meta
86+
skipped_columns_by_source_id[source_id] = set(meta)
87+
88+
return prepared_records, SkipMetadataRestoreContext(
89+
restore_id_column=restore_id_column,
90+
source_ids=source_ids,
91+
skipped_columns_by_source_id=skipped_columns_by_source_id,
92+
)
93+
94+
95+
def restore_skip_metadata(
96+
records: Sequence[dict],
97+
*,
98+
context: SkipMetadataRestoreContext,
99+
allow_resize: bool,
100+
) -> None:
101+
"""Restore skip provenance using hidden restore IDs instead of row position."""
102+
restored_source_ids: list[str] = []
103+
for record in records:
104+
if context.restore_id_column not in record:
105+
raise ValueError(
106+
f"Records returned from the DataFrame round-trip must preserve "
107+
f"the internal column {context.restore_id_column!r} so skip "
108+
"provenance can be restored."
109+
)
110+
111+
source_id = str(record.pop(context.restore_id_column))
112+
if source_id not in context.source_ids:
113+
raise ValueError(
114+
f"Record returned unknown restore ID {source_id!r}. Skip provenance "
115+
"can only be restored for rows derived from the original input."
116+
)
117+
118+
restored_source_ids.append(source_id)
119+
meta = context.skipped_columns_by_source_id.get(source_id)
120+
if meta is not None:
121+
record[SKIPPED_COLUMNS_RECORD_KEY] = set(meta)
122+
123+
if not allow_resize:
124+
if len(restored_source_ids) != len(context.source_ids) or set(restored_source_ids) != context.source_ids:
125+
raise ValueError(
126+
"Full-column generation changed the row identity mapping while "
127+
"allow_resize=False. Returned rows must preserve a 1:1 mapping "
128+
"to the original input so skip provenance can be restored."
129+
)
130+
131+
132+
def _choose_restore_id_column(records: Sequence[dict]) -> str:
133+
candidate = SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX
134+
suffix = 0
135+
while any(candidate in record for record in records):
136+
suffix += 1
137+
candidate = f"{SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX}_{suffix}"
138+
return candidate

packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,12 @@ def _resize_full_keep_first(df: pd.DataFrame) -> pd.DataFrame:
805805
return df.drop_duplicates(subset="seed_id").assign(filtered=True)
806806

807807

808+
@custom_column_generator(required_columns=["seed_id"])
809+
def _resize_full_drop_seed_one(df: pd.DataFrame) -> pd.DataFrame:
810+
"""FULL_COLUMN: drop the row with seed_id == 1."""
811+
return df[df["seed_id"] != 1].reset_index(drop=True).assign(filtered=True)
812+
813+
808814
@custom_column_generator(required_columns=["seed_id"])
809815
def _resize_cell_expand(row: dict) -> list[dict]:
810816
"""CELL_BY_CELL: one row -> two rows (doubled)."""
@@ -1128,6 +1134,49 @@ def test_skip_propagation_resolves_side_effect_dependencies_in_sync_builder(
11281134
assert row["analysis"] == "generated_analysis", f"seed_id={row['seed_id']}: analysis should be generated"
11291135

11301136

1137+
def test_skip_metadata_restore_preserves_row_identity_across_allow_resize_full_column(
1138+
stub_resource_provider, stub_model_configs, seed_data_setup
1139+
):
1140+
"""Filtering out a skipped row must not transfer its skip provenance to surviving rows."""
1141+
config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
1142+
config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"])))
1143+
1144+
config_builder.add_column(
1145+
CustomColumnConfig(
1146+
name="review",
1147+
generator_function=_make_label_generator("review", "seed_id"),
1148+
generation_strategy=GenerationStrategy.FULL_COLUMN,
1149+
skip=SkipConfig(when="{{ seed_id == 1 }}"),
1150+
)
1151+
)
1152+
config_builder.add_column(
1153+
CustomColumnConfig(
1154+
name="filtered",
1155+
generator_function=_resize_full_drop_seed_one,
1156+
generation_strategy=GenerationStrategy.FULL_COLUMN,
1157+
allow_resize=True,
1158+
propagate_skip=False,
1159+
)
1160+
)
1161+
config_builder.add_column(
1162+
CustomColumnConfig(
1163+
name="analysis",
1164+
generator_function=_make_label_generator("analysis", "review"),
1165+
generation_strategy=GenerationStrategy.FULL_COLUMN,
1166+
propagate_skip=True,
1167+
)
1168+
)
1169+
1170+
builder = DatasetBuilder(
1171+
data_designer_config=config_builder.build(),
1172+
resource_provider=stub_resource_provider,
1173+
)
1174+
result = builder.build_preview(num_records=5)
1175+
1176+
assert result["seed_id"].tolist() == [2, 3, 4, 5]
1177+
assert result["analysis"].tolist() == ["generated_analysis"] * 4
1178+
1179+
11311180
def test_allow_resize_column_not_blocked_by_upstream_skip(stub_resource_provider, stub_model_configs, seed_data_setup):
11321181
"""An allow_resize=True column depending on a skippable upstream must not
11331182
enter the skip-aware branch (which enforces 1:1 row counts).

packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from data_designer.engine.dataset_builders.utils.skip_tracker import (
1010
SKIPPED_COLUMNS_RECORD_KEY,
1111
apply_skip_to_record,
12+
prepare_records_for_skip_metadata_round_trip,
1213
restore_skip_metadata,
1314
strip_skip_metadata_for_dataframe_row,
1415
strip_skip_metadata_from_records,
@@ -131,32 +132,75 @@ def test_strip_skip_metadata_from_records(rows: list[dict], expected: list[dict]
131132
assert strip_skip_metadata_from_records(rows) == expected
132133

133134

134-
def test_restore_skip_metadata_copies_metadata() -> None:
135-
old = [
135+
def test_prepare_records_for_skip_metadata_round_trip_without_metadata() -> None:
136+
rows = [{"a": 1}, {"a": 2}]
137+
prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(rows)
138+
assert restore_context is None
139+
assert prepared_rows == rows
140+
assert prepared_rows is not rows
141+
142+
143+
def test_prepare_records_for_skip_metadata_round_trip_injects_restore_ids() -> None:
144+
rows = [
136145
{"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}},
137146
{"a": 2},
138147
{"a": 3, SKIPPED_COLUMNS_RECORD_KEY: {"col_y", "col_z"}},
139148
]
140-
new = [{"a": 10}, {"a": 20}, {"a": 30}]
141-
restore_skip_metadata(old, new)
142-
assert new[0][SKIPPED_COLUMNS_RECORD_KEY] == {"col_x"}
143-
assert SKIPPED_COLUMNS_RECORD_KEY not in new[1]
144-
assert new[2][SKIPPED_COLUMNS_RECORD_KEY] == {"col_y", "col_z"}
149+
prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(rows)
150+
assert restore_context is not None
151+
assert SKIPPED_COLUMNS_RECORD_KEY not in prepared_rows[0]
152+
assert restore_context.restore_id_column in prepared_rows[0]
153+
assert restore_context.skipped_columns_by_source_id == {
154+
"0": {"col_x"},
155+
"2": {"col_y", "col_z"},
156+
}
145157

146158

147-
def test_restore_skip_metadata_handles_length_mismatch() -> None:
159+
def test_restore_skip_metadata_uses_restore_ids_after_reorder() -> None:
148160
old = [
149161
{"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}},
150-
{"a": 2, SKIPPED_COLUMNS_RECORD_KEY: {"col_y"}},
162+
{"a": 2},
163+
{"a": 3, SKIPPED_COLUMNS_RECORD_KEY: {"col_z"}},
151164
]
152-
new = [{"a": 10}]
153-
restore_skip_metadata(old, new)
154-
assert new[0][SKIPPED_COLUMNS_RECORD_KEY] == {"col_x"}
165+
prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old)
166+
assert restore_context is not None
167+
restore_id_column = restore_context.restore_id_column
168+
169+
new = [
170+
{"a": 30, restore_id_column: prepared_rows[2][restore_id_column]},
171+
{"a": 10, restore_id_column: prepared_rows[0][restore_id_column]},
172+
{"a": 20, restore_id_column: prepared_rows[1][restore_id_column]},
173+
]
174+
restore_skip_metadata(new, context=restore_context, allow_resize=False)
175+
176+
assert new[0][SKIPPED_COLUMNS_RECORD_KEY] == {"col_z"}
177+
assert new[1][SKIPPED_COLUMNS_RECORD_KEY] == {"col_x"}
178+
assert SKIPPED_COLUMNS_RECORD_KEY not in new[2]
155179

156180

157-
def test_restore_skip_metadata_no_metadata() -> None:
181+
def test_restore_skip_metadata_allow_resize_handles_filtered_rows() -> None:
158182
old = [{"a": 1}, {"a": 2}]
159-
new = [{"a": 10}, {"a": 20}]
160-
restore_skip_metadata(old, new)
183+
prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old)
184+
assert restore_context is None
185+
186+
old = [
187+
{"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}},
188+
{"a": 2},
189+
]
190+
prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old)
191+
assert restore_context is not None
192+
restore_id_column = restore_context.restore_id_column
193+
194+
new = [{"a": 20, restore_id_column: prepared_rows[1][restore_id_column]}]
195+
restore_skip_metadata(new, context=restore_context, allow_resize=True)
196+
161197
assert SKIPPED_COLUMNS_RECORD_KEY not in new[0]
162-
assert SKIPPED_COLUMNS_RECORD_KEY not in new[1]
198+
199+
200+
def test_restore_skip_metadata_rejects_missing_restore_id_column() -> None:
201+
old = [{"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}]
202+
_prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old)
203+
assert restore_context is not None
204+
205+
with pytest.raises(ValueError, match="must preserve the internal column"):
206+
restore_skip_metadata([{"a": 10}], context=restore_context, allow_resize=False)

0 commit comments

Comments
 (0)