Skip to content

Commit 211e538

Browse files
Merge pull request #3808 from AI-Hypercomputer:aireen/grain_error
PiperOrigin-RevId: 910999131
2 parents d029d47 + 29371f8 commit 211e538

2 files changed

Lines changed: 25 additions & 6 deletions

File tree

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -556,14 +556,20 @@ def map(self, element):
556556
example.ParseFromString(element)
557557
features = example.features.feature
558558

559+
missing = [c for c in self.data_columns if c not in features]
560+
if missing:
561+
raise ValueError(
562+
f"Column {missing} not found in dataset. Available columns: {sorted(features.keys())}. "
563+
"Please set train_data_columns or eval_data_columns accordingly."
564+
)
565+
559566
parsed = {}
560567
for col in self.data_columns:
561-
if col in features:
562-
f = features[col]
563-
if self.tokenize:
564-
parsed[col] = np.array(f.bytes_list.value, dtype=object)
565-
else:
566-
parsed[col] = np.array(f.int64_list.value, dtype=np.int32)
568+
f = features[col]
569+
if self.tokenize:
570+
parsed[col] = np.array(f.bytes_list.value, dtype=object)
571+
else:
572+
parsed[col] = np.array(f.int64_list.value, dtype=np.int32)
567573
return parsed
568574

569575

@@ -601,6 +607,12 @@ def __init__(self, feature_names: list[str]):
601607

602608
def map(self, element: dict[str, Any]) -> dict[str, Any]:
603609
"""Applies the feature filtering to the input element."""
610+
missing = [n for n in self.feature_names if n not in element]
611+
if missing:
612+
raise ValueError(
613+
f"Column {missing} not found in dataset. Available columns: {sorted(element.keys())}. "
614+
"Please set train_data_columns or eval_data_columns accordingly."
615+
)
604616
return {k: v for k, v in element.items() if k in self.feature_names}
605617

606618

src/maxtext/input_pipeline/tfds_data_processing.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ def preprocessing_pipeline(
9595
hf_access_token: str = "",
9696
):
9797
"""pipeline for preprocessing TFDS dataset."""
98+
missing = [c for c in data_column_names if c not in dataset.element_spec]
99+
if missing:
100+
raise ValueError(
101+
f"Column {missing} not found in dataset. Available columns: {sorted(dataset.element_spec.keys())}. "
102+
"Please set train_data_columns or eval_data_columns accordingly."
103+
)
104+
98105
if not use_dpo:
99106
assert len(data_column_names) == 1
100107
dataset = dataset.map(

0 commit comments

Comments
 (0)