Skip to content

Commit 4abf1b5

Browse files
committed
Improve error message when tokenize_train/eval_data doesn't match the dataset
1 parent 211e538 commit 4abf1b5

3 files changed

Lines changed: 80 additions & 37 deletions

File tree

src/maxtext/input_pipeline/data_processing_utils.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,7 @@ def parse_and_keep_features(dataset, config, data_columns, tokenize):
2929
"""Parse arrayrecord features or keep specified columns for other formats."""
3030
if config.grain_file_type in ("arrayrecord", "tfrecord"):
3131
dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
32-
dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
33-
else:
34-
dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
32+
dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
3533
return dataset
3634

3735

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 77 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,34 @@ def _process_string(string_tensor):
8989
return features
9090

9191

92+
def validate_tfds_data_types(dataset, data_keys, tokenize):
93+
"""Validate that TFDS dataset types match the tokenization configuration."""
94+
import tensorflow as tf # pylint: disable=import-outside-toplevel
95+
96+
spec = dataset.element_spec
97+
for k in data_keys:
98+
if k not in spec:
99+
continue
100+
101+
# For TFDS, strings/bytes are usually tf.string
102+
is_string = spec[k].dtype == tf.string
103+
104+
if tokenize and not is_string:
105+
raise ValueError(
106+
f"TFDS dataset column '{k}' has type {spec[k].dtype}, but tokenization is enabled. "
107+
"This often happens if the dataset is already tokenized (contains integers) "
108+
"but 'tokenize_train_data=True' (or 'tokenize_eval_data=True') is set in the configuration. "
109+
"If your data is already tokenized, please set it to False."
110+
)
111+
if not tokenize and is_string:
112+
raise ValueError(
113+
f"TFDS dataset column '{k}' has type {spec[k].dtype}, but tokenization is disabled. "
114+
"This often happens if the dataset is NOT tokenized (contains strings) "
115+
"but 'tokenize_train_data=False' (or 'tokenize_eval_data=False') is set in the configuration. "
116+
"If your data is not tokenized, please set it to True."
117+
)
118+
119+
92120
########## Functions used by HF pipeline
93121

94122

@@ -544,7 +572,7 @@ def make_tfrecord_iter_dataset(path: str):
544572

545573
@dataclasses.dataclass
546574
class ParseFeatures(grain.MapTransform):
547-
"""Parse serialized example"""
575+
"""Parse serialized example proto into a dictionary of arrays."""
548576

549577
def __init__(self, data_columns, tokenize):
550578
self.data_columns = data_columns
@@ -575,45 +603,60 @@ def map(self, element):
575603

576604
@dataclasses.dataclass
577605
class NormalizeFeatures(grain.MapTransform):
578-
"""Normalize text feature keys."""
606+
"""Universal feature normalizer and validator. Acts as selector and validator."""
579607

580608
def __init__(self, column_names, tokenize):
581609
self.column_names = column_names
582610
self.tokenize = tokenize
583611

584612
def map(self, element):
585-
if self.tokenize:
586-
return {col: element[col][0].decode() for col in self.column_names}
587-
else:
588-
return {col: element[col] for col in self.column_names}
589-
590-
591-
@dataclasses.dataclass
592-
class KeepFeatures(grain.MapTransform):
593-
"""Keep only specified features in the dataset element.
594-
595-
This transform filters the input dictionary, retaining only the keys
596-
that are present in `feature_names`.
597-
"""
598-
599-
def __init__(self, feature_names: list[str]):
600-
"""Initializes the KeepFeatures transform.
601-
602-
Args:
603-
feature_names: A list of strings, where each string is the name of a
604-
feature to be kept in the dataset element.
605-
"""
606-
self.feature_names = feature_names
607-
608-
def map(self, element: dict[str, Any]) -> dict[str, Any]:
609-
"""Applies the feature filtering to the input element."""
610-
missing = [n for n in self.feature_names if n not in element]
611-
if missing:
612-
raise ValueError(
613-
f"Column {missing} not found in dataset. Available columns: {sorted(element.keys())}. "
614-
"Please set train_data_columns or eval_data_columns accordingly."
615-
)
616-
return {k: v for k, v in element.items() if k in self.feature_names}
613+
res = {}
614+
for col in self.column_names:
615+
val = element[col]
616+
if self.tokenize:
617+
# ArrayRecord/TFRecord: ParseFeatures wraps bytes_list.value in np.ndarray(dtype=object).
618+
# An empty array means the proto stored data in int64_list (already tokenized) — user config error.
619+
if isinstance(val, (list, np.ndarray)):
620+
if len(val) == 0:
621+
raise ValueError(
622+
f"Expected non-empty string/bytes list for column '{col}' because tokenization is enabled, "
623+
"but got an empty list. This often happens if the dataset is already tokenized (contains integers) "
624+
"but tokenization is enabled in the configuration. "
625+
"If your data is already tokenized, please set 'tokenize_train_data=False' "
626+
"(or 'tokenize_eval_data=False')."
627+
)
628+
val = val[0] # unwrap the single-element array from ParseFeatures
629+
630+
# ArrayRecord/TFRecord: proto bytes_list values are Python bytes after unwrapping above.
631+
if isinstance(val, bytes):
632+
val = val.decode("utf-8")
633+
634+
# Parquet: string columns arrive as scalar str (no unwrapping needed).
635+
# Any other type indicates a misconfiguration (e.g. already-tokenized integers).
636+
if not isinstance(val, str):
637+
raise ValueError(
638+
f"Expected string or bytes for column '{col}' but got {type(val)} with value {val}. "
639+
"If your data is already tokenized, please set 'tokenize_train_data=False' "
640+
"(or 'tokenize_eval_data=False') in the configuration."
641+
)
642+
res[col] = val
643+
else:
644+
# Parquet: text column arrives as scalar str/bytes — user forgot to pre-tokenize.
645+
if isinstance(val, (str, bytes)):
646+
raise ValueError(
647+
f"Expected tokenized integers for column '{col}' because tokenization is disabled, "
648+
f"but got strings. If your data is NOT tokenized, please set 'tokenize_train_data=True' "
649+
"(or 'tokenize_eval_data=True') in the configuration."
650+
)
651+
# ArrayRecord/TFRecord: ParseFeatures reads from int64_list; an empty array means the data
652+
# was stored in bytes_list (not tokenized) — user forgot to enable tokenization.
653+
if isinstance(val, (list, np.ndarray)) and len(val) == 0:
654+
raise ValueError(
655+
f"Column '{col}' is empty. This often happens if the dataset contains strings "
656+
"but tokenization is disabled (looking for integers)."
657+
)
658+
res[col] = val
659+
return res
617660

618661

619662
@dataclasses.dataclass

src/maxtext/input_pipeline/tfds_data_processing.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,8 @@ def preprocessing_pipeline(
102102
"Please set train_data_columns or eval_data_columns accordingly."
103103
)
104104

105+
input_pipeline_utils.validate_tfds_data_types(dataset, data_column_names, tokenize)
106+
105107
if not use_dpo:
106108
assert len(data_column_names) == 1
107109
dataset = dataset.map(

0 commit comments

Comments
 (0)