|
39 | 39 | Features = dict[str, Any] |
40 | 40 | INPUT_TOKENS_KEY = "input_ids" |
41 | 41 |
|
| 42 | +# pylint: disable=protected-access |
| 43 | + |
42 | 44 | ########## Functions used by TFDS pipeline |
43 | 45 |
|
44 | 46 |
|
@@ -89,6 +91,34 @@ def _process_string(string_tensor): |
89 | 91 | return features |
90 | 92 |
|
91 | 93 |
|
def validate_tfds_data_types(dataset, data_keys, tokenize):
  """Check that the selected TFDS columns are consistent with the tokenization flag.

  Every key in `data_keys` that appears in `dataset.element_spec` must have dtype
  `tf.string` when `tokenize` is truthy and a non-string dtype otherwise. Keys that
  are absent from the element spec are ignored.

  Args:
    dataset: Object exposing an `element_spec` mapping whose values have a `dtype`.
    data_keys: Column names to check.
    tokenize: Whether tokenization is enabled for these columns.

  Raises:
    ValueError: If a column's dtype contradicts the tokenization setting.
  """
  import tensorflow as tf  # pylint: disable=import-outside-toplevel

  element_spec = dataset.element_spec
  for column in data_keys:
    if column in element_spec:
      column_dtype = element_spec[column].dtype
      # For TFDS, strings/bytes are usually tf.string
      holds_text = column_dtype == tf.string

      if tokenize:
        if not holds_text:
          raise ValueError(
              f"TFDS dataset column '{column}' has type {column_dtype}, but tokenization is enabled. "
              "This often happens if the dataset is already tokenized (contains integers) "
              "but 'tokenize_train_data=True' (or 'tokenize_eval_data=True') is set in the configuration. "
              "If your data is already tokenized, please set it to False."
          )
      elif holds_text:
        raise ValueError(
            f"TFDS dataset column '{column}' has type {column_dtype}, but tokenization is disabled. "
            "This often happens if the dataset is NOT tokenized (contains strings) "
            "but 'tokenize_train_data=False' (or 'tokenize_eval_data=False') is set in the configuration. "
            "If your data is not tokenized, please set it to True."
        )
| 120 | + |
| 121 | + |
92 | 122 | ########## Functions used by HF pipeline |
93 | 123 |
|
94 | 124 |
|
@@ -544,7 +574,7 @@ def make_tfrecord_iter_dataset(path: str): |
544 | 574 |
|
545 | 575 | @dataclasses.dataclass |
546 | 576 | class ParseFeatures(grain.MapTransform): |
547 | | - """Parse serialized example""" |
| 577 | + """Parse serialized example proto into a dictionary of arrays.""" |
548 | 578 |
|
549 | 579 | def __init__(self, data_columns, tokenize): |
550 | 580 | self.data_columns = data_columns |
@@ -575,45 +605,60 @@ def map(self, element): |
575 | 605 |
|
@dataclasses.dataclass
class NormalizeFeatures(grain.MapTransform):
  """Select the configured columns and validate them against the tokenization setting.

  With `tokenize` enabled, each selected column is returned as a Python `str`
  (single-element lists/arrays are unwrapped and `bytes` are UTF-8 decoded).
  With `tokenize` disabled, values are passed through unchanged after checking
  that they are neither raw text nor empty.

  Args:
    column_names: Names of the columns to keep from each element.
    tokenize: Whether tokenization is enabled, i.e. the columns are expected to hold raw text.
  """

  def __init__(self, column_names, tokenize):
    self.column_names = column_names
    self.tokenize = tokenize

  def map(self, element):
    """Return a dict holding only `column_names`, normalized and validated.

    Args:
      element: Mapping from column name to the value for one example.

    Returns:
      A new dict with one entry per name in `column_names`.

    Raises:
      ValueError: If a requested column is absent from `element`, or if a column's
        value contradicts the tokenization setting.
    """
    # Report absent columns with the list of available ones instead of a bare KeyError.
    missing = [col for col in self.column_names if col not in element]
    if missing:
      raise ValueError(
          f"Column {missing} not found in dataset. Available columns: {sorted(element)}. "
          "Please set train_data_columns or eval_data_columns accordingly."
      )
    normalize = self._as_text if self.tokenize else self._as_tokens
    return {col: normalize(col, element[col]) for col in self.column_names}

  @staticmethod
  def _as_text(col, val):
    """Return `val` as `str`; raise ValueError if it is not text (tokenization enabled)."""
    # ArrayRecord/TFRecord: ParseFeatures wraps bytes_list.value in np.ndarray(dtype=object).
    # An empty array means the proto stored data in int64_list (already tokenized) — user config error.
    if isinstance(val, (list, np.ndarray)):
      if len(val) == 0:
        raise ValueError(
            f"Expected non-empty string/bytes list for column '{col}' because tokenization is enabled, "
            "but got an empty list. This often happens if the dataset is already tokenized (contains integers) "
            "but tokenization is enabled in the configuration. "
            "If your data is already tokenized, please set 'tokenize_train_data=False' "
            "(or 'tokenize_eval_data=False')."
        )
      val = val[0]  # unwrap the single-element array from ParseFeatures

    # ArrayRecord/TFRecord: proto bytes_list values are Python bytes after unwrapping above.
    if isinstance(val, bytes):
      val = val.decode("utf-8")

    # Parquet: string columns arrive as scalar str (no unwrapping needed).
    # Any other type indicates a misconfiguration (e.g. already-tokenized integers).
    if not isinstance(val, str):
      raise ValueError(
          f"Expected string or bytes for column '{col}' but got {type(val)} with value {val}. "
          "If your data is already tokenized, please set 'tokenize_train_data=False' "
          "(or 'tokenize_eval_data=False') in the configuration."
      )
    return val

  @staticmethod
  def _as_tokens(col, val):
    """Return `val` unchanged; raise ValueError if it is raw text or empty (tokenization disabled)."""
    # Parquet: text column arrives as scalar str/bytes — user forgot to pre-tokenize.
    if isinstance(val, (str, bytes)):
      raise ValueError(
          f"Expected tokenized integers for column '{col}' because tokenization is disabled, "
          "but got strings. If your data is NOT tokenized, please set 'tokenize_train_data=True' "
          "(or 'tokenize_eval_data=True') in the configuration."
      )
    # ArrayRecord/TFRecord: ParseFeatures reads from int64_list; an empty array means the data
    # was stored in bytes_list (not tokenized) — user forgot to enable tokenization.
    if isinstance(val, (list, np.ndarray)) and len(val) == 0:
      raise ValueError(
          f"Column '{col}' is empty. This often happens if the dataset contains strings "
          "but tokenization is disabled (looking for integers)."
      )
    return val
617 | 662 |
|
618 | 663 |
|
619 | 664 | @dataclasses.dataclass |
|
0 commit comments