Skip to content

Commit c81e688

Browse files
committed
Improve error message when tokenize_train/eval_data doesn't match the dataset
1 parent 211e538 commit c81e688

3 files changed

Lines changed: 82 additions & 37 deletions

File tree

src/maxtext/input_pipeline/data_processing_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ def parse_and_keep_features(dataset, config, data_columns, tokenize):
2929
"""Parse arrayrecord features or keep specified columns for other formats."""
3030
if config.grain_file_type in ("arrayrecord", "tfrecord"):
3131
dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
32-
dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
33-
else:
34-
dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
32+
dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
3533
return dataset
3634

3735

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 79 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
Features = dict[str, Any]
4040
INPUT_TOKENS_KEY = "input_ids"
4141

42+
# pylint: disable=protected-access
43+
4244
########## Functions used by TFDS pipeline
4345

4446

@@ -89,6 +91,34 @@ def _process_string(string_tensor):
8991
return features
9092

9193

94+
def validate_tfds_data_types(dataset, data_keys, tokenize):
95+
"""Validate that TFDS dataset types match the tokenization configuration."""
96+
import tensorflow as tf # pylint: disable=import-outside-toplevel
97+
98+
spec = dataset.element_spec
99+
for k in data_keys:
100+
if k not in spec:
101+
continue
102+
103+
# For TFDS, strings/bytes are usually tf.string
104+
is_string = spec[k].dtype == tf.string
105+
106+
if tokenize and not is_string:
107+
raise ValueError(
108+
f"TFDS dataset column '{k}' has type {spec[k].dtype}, but tokenization is enabled. "
109+
"This often happens if the dataset is already tokenized (contains integers) "
110+
"but 'tokenize_train_data=True' (or 'tokenize_eval_data=True') is set in the configuration. "
111+
"If your data is already tokenized, please set it to False."
112+
)
113+
if not tokenize and is_string:
114+
raise ValueError(
115+
f"TFDS dataset column '{k}' has type {spec[k].dtype}, but tokenization is disabled. "
116+
"This often happens if the dataset is NOT tokenized (contains strings) "
117+
"but 'tokenize_train_data=False' (or 'tokenize_eval_data=False') is set in the configuration. "
118+
"If your data is not tokenized, please set it to True."
119+
)
120+
121+
92122
########## Functions used by HF pipeline
93123

94124

@@ -544,7 +574,7 @@ def make_tfrecord_iter_dataset(path: str):
544574

545575
@dataclasses.dataclass
546576
class ParseFeatures(grain.MapTransform):
547-
"""Parse serialized example"""
577+
"""Parse serialized example proto into a dictionary of arrays."""
548578

549579
def __init__(self, data_columns, tokenize):
550580
self.data_columns = data_columns
@@ -575,45 +605,60 @@ def map(self, element):
575605

576606
@dataclasses.dataclass
577607
class NormalizeFeatures(grain.MapTransform):
578-
"""Normalize text feature keys."""
608+
"""Universal feature normalizer and validator. Acts as selector and validator."""
579609

580610
def __init__(self, column_names, tokenize):
581611
self.column_names = column_names
582612
self.tokenize = tokenize
583613

584614
def map(self, element):
585-
if self.tokenize:
586-
return {col: element[col][0].decode() for col in self.column_names}
587-
else:
588-
return {col: element[col] for col in self.column_names}
589-
590-
591-
@dataclasses.dataclass
592-
class KeepFeatures(grain.MapTransform):
593-
"""Keep only specified features in the dataset element.
594-
595-
This transform filters the input dictionary, retaining only the keys
596-
that are present in `feature_names`.
597-
"""
598-
599-
def __init__(self, feature_names: list[str]):
600-
"""Initializes the KeepFeatures transform.
601-
602-
Args:
603-
feature_names: A list of strings, where each string is the name of a
604-
feature to be kept in the dataset element.
605-
"""
606-
self.feature_names = feature_names
607-
608-
def map(self, element: dict[str, Any]) -> dict[str, Any]:
609-
"""Applies the feature filtering to the input element."""
610-
missing = [n for n in self.feature_names if n not in element]
611-
if missing:
612-
raise ValueError(
613-
f"Column {missing} not found in dataset. Available columns: {sorted(element.keys())}. "
614-
"Please set train_data_columns or eval_data_columns accordingly."
615-
)
616-
return {k: v for k, v in element.items() if k in self.feature_names}
615+
res = {}
616+
for col in self.column_names:
617+
val = element[col]
618+
if self.tokenize:
619+
# ArrayRecord/TFRecord: ParseFeatures wraps bytes_list.value in np.ndarray(dtype=object).
620+
# An empty array means the proto stored data in int64_list (already tokenized) — user config error.
621+
if isinstance(val, (list, np.ndarray)):
622+
if len(val) == 0:
623+
raise ValueError(
624+
f"Expected non-empty string/bytes list for column '{col}' because tokenization is enabled, "
625+
"but got an empty list. This often happens if the dataset is already tokenized (contains integers) "
626+
"but tokenization is enabled in the configuration. "
627+
"If your data is already tokenized, please set 'tokenize_train_data=False' "
628+
"(or 'tokenize_eval_data=False')."
629+
)
630+
val = val[0] # unwrap the single-element array from ParseFeatures
631+
632+
# ArrayRecord/TFRecord: proto bytes_list values are Python bytes after unwrapping above.
633+
if isinstance(val, bytes):
634+
val = val.decode("utf-8")
635+
636+
# Parquet: string columns arrive as scalar str (no unwrapping needed).
637+
# Any other type indicates a misconfiguration (e.g. already-tokenized integers).
638+
if not isinstance(val, str):
639+
raise ValueError(
640+
f"Expected string or bytes for column '{col}' but got {type(val)} with value {val}. "
641+
"If your data is already tokenized, please set 'tokenize_train_data=False' "
642+
"(or 'tokenize_eval_data=False') in the configuration."
643+
)
644+
res[col] = val
645+
else:
646+
# Parquet: text column arrives as scalar str/bytes — user forgot to pre-tokenize.
647+
if isinstance(val, (str, bytes)):
648+
raise ValueError(
649+
f"Expected tokenized integers for column '{col}' because tokenization is disabled, "
650+
f"but got strings. If your data is NOT tokenized, please set 'tokenize_train_data=True' "
651+
"(or 'tokenize_eval_data=True') in the configuration."
652+
)
653+
# ArrayRecord/TFRecord: ParseFeatures reads from int64_list; an empty array means the data
654+
# was stored in bytes_list (not tokenized) — user forgot to enable tokenization.
655+
if isinstance(val, (list, np.ndarray)) and len(val) == 0:
656+
raise ValueError(
657+
f"Column '{col}' is empty. This often happens if the dataset contains strings "
658+
"but tokenization is disabled (looking for integers)."
659+
)
660+
res[col] = val
661+
return res
617662

618663

619664
@dataclasses.dataclass

src/maxtext/input_pipeline/tfds_data_processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def preprocessing_pipeline(
102102
"Please set train_data_columns or eval_data_columns accordingly."
103103
)
104104

105+
input_pipeline_utils.validate_tfds_data_types(dataset, data_column_names, tokenize)
106+
105107
if not use_dpo:
106108
assert len(data_column_names) == 1
107109
dataset = dataset.map(

0 commit comments

Comments
 (0)