@@ -89,6 +89,34 @@ def _process_string(string_tensor):
8989 return features
9090
9191
92+ def validate_tfds_data_types (dataset , data_keys , tokenize ):
93+ """Validate that TFDS dataset types match the tokenization configuration."""
94+ import tensorflow as tf # pylint: disable=import-outside-toplevel
95+
96+ spec = dataset .element_spec
97+ for k in data_keys :
98+ if k not in spec :
99+ continue
100+
101+ # For TFDS, strings/bytes are usually tf.string
102+ is_string = spec [k ].dtype == tf .string
103+
104+ if tokenize and not is_string :
105+ raise ValueError (
106+ f"TFDS dataset column '{ k } ' has type { spec [k ].dtype } , but tokenization is enabled. "
107+ "This often happens if the dataset is already tokenized (contains integers) "
108+ "but 'tokenize_train_data=True' (or 'tokenize_eval_data=True') is set in the configuration. "
109+ "If your data is already tokenized, please set it to False."
110+ )
111+ if not tokenize and is_string :
112+ raise ValueError (
113+ f"TFDS dataset column '{ k } ' has type { spec [k ].dtype } , but tokenization is disabled. "
114+ "This often happens if the dataset is NOT tokenized (contains strings) "
115+ "but 'tokenize_train_data=False' (or 'tokenize_eval_data=False') is set in the configuration. "
116+ "If your data is not tokenized, please set it to True."
117+ )
118+
119+
92120########## Functions used by HF pipeline
93121
94122
@@ -544,7 +572,7 @@ def make_tfrecord_iter_dataset(path: str):
544572
545573@dataclasses .dataclass
546574class ParseFeatures (grain .MapTransform ):
547- """Parse serialized example"""
575+ """Parse serialized example proto into a dictionary of arrays. """
548576
549577 def __init__ (self , data_columns , tokenize ):
550578 self .data_columns = data_columns
@@ -575,45 +603,60 @@ def map(self, element):
575603
576604@dataclasses .dataclass
577605class NormalizeFeatures (grain .MapTransform ):
578- """Normalize text feature keys ."""
606+ """Universal feature normalizer and validator. Acts as selector and validator ."""
579607
580608 def __init__ (self , column_names , tokenize ):
581609 self .column_names = column_names
582610 self .tokenize = tokenize
583611
584612 def map (self , element ):
585- if self .tokenize :
586- return {col : element [col ][0 ].decode () for col in self .column_names }
587- else :
588- return {col : element [col ] for col in self .column_names }
589-
590-
591- @dataclasses .dataclass
592- class KeepFeatures (grain .MapTransform ):
593- """Keep only specified features in the dataset element.
594-
595- This transform filters the input dictionary, retaining only the keys
596- that are present in `feature_names`.
597- """
598-
599- def __init__ (self , feature_names : list [str ]):
600- """Initializes the KeepFeatures transform.
601-
602- Args:
603- feature_names: A list of strings, where each string is the name of a
604- feature to be kept in the dataset element.
605- """
606- self .feature_names = feature_names
607-
608- def map (self , element : dict [str , Any ]) -> dict [str , Any ]:
609- """Applies the feature filtering to the input element."""
610- missing = [n for n in self .feature_names if n not in element ]
611- if missing :
612- raise ValueError (
613- f"Column { missing } not found in dataset. Available columns: { sorted (element .keys ())} . "
614- "Please set train_data_columns or eval_data_columns accordingly."
615- )
616- return {k : v for k , v in element .items () if k in self .feature_names }
613+ res = {}
614+ for col in self .column_names :
615+ val = element [col ]
616+ if self .tokenize :
617+ # ArrayRecord/TFRecord: ParseFeatures wraps bytes_list.value in np.ndarray(dtype=object).
618+ # An empty array means the proto stored data in int64_list (already tokenized) — user config error.
619+ if isinstance (val , (list , np .ndarray )):
620+ if len (val ) == 0 :
621+ raise ValueError (
622+ f"Expected non-empty string/bytes list for column '{ col } ' because tokenization is enabled, "
623+ "but got an empty list. This often happens if the dataset is already tokenized (contains integers) "
624+ "but tokenization is enabled in the configuration. "
625+ "If your data is already tokenized, please set 'tokenize_train_data=False' "
626+ "(or 'tokenize_eval_data=False')."
627+ )
628+ val = val [0 ] # unwrap the single-element array from ParseFeatures
629+
630+ # ArrayRecord/TFRecord: proto bytes_list values are Python bytes after unwrapping above.
631+ if isinstance (val , bytes ):
632+ val = val .decode ("utf-8" )
633+
634+ # Parquet: string columns arrive as scalar str (no unwrapping needed).
635+ # Any other type indicates a misconfiguration (e.g. already-tokenized integers).
636+ if not isinstance (val , str ):
637+ raise ValueError (
638+ f"Expected string or bytes for column '{ col } ' but got { type (val )} with value { val } . "
639+ "If your data is already tokenized, please set 'tokenize_train_data=False' "
640+ "(or 'tokenize_eval_data=False') in the configuration."
641+ )
642+ res [col ] = val
643+ else :
644+ # Parquet: text column arrives as scalar str/bytes — user forgot to pre-tokenize.
645+ if isinstance (val , (str , bytes )):
646+ raise ValueError (
647+ f"Expected tokenized integers for column '{ col } ' because tokenization is disabled, "
648+ f"but got strings. If your data is NOT tokenized, please set 'tokenize_train_data=True' "
649+ "(or 'tokenize_eval_data=True') in the configuration."
650+ )
651+ # ArrayRecord/TFRecord: ParseFeatures reads from int64_list; an empty array means the data
652+ # was stored in bytes_list (not tokenized) — user forgot to enable tokenization.
653+ if isinstance (val , (list , np .ndarray )) and len (val ) == 0 :
654+ raise ValueError (
655+ f"Column '{ col } ' is empty. This often happens if the dataset contains strings "
656+ "but tokenization is disabled (looking for integers)."
657+ )
658+ res [col ] = val
659+ return res
617660
618661
619662@dataclasses .dataclass
0 commit comments