|
| 1 | +# Copyright 2023–2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Utility functions for data processing pipelines.""" |
| 16 | + |
| 17 | +import functools |
| 18 | + |
| 19 | +import jax |
| 20 | +from grain.experimental import BestFitPackIterDataset, pick_performance_config |
| 21 | +import grain.python as grain |
| 22 | + |
| 23 | +from maxtext.input_pipeline import input_pipeline_utils |
| 24 | +from maxtext.input_pipeline import tokenizer |
| 25 | + |
| 26 | + |
| 27 | +def parse_and_keep_features(dataset, config, data_columns, tokenize): |
| 28 | + """Parse arrayrecord features or keep specified columns for other formats.""" |
| 29 | + if config.grain_file_type in ("arrayrecord", "tfrecord"): |
| 30 | + dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize)) |
| 31 | + dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) |
| 32 | + else: |
| 33 | + dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns)) |
| 34 | + return dataset |
| 35 | + |
| 36 | + |
| 37 | +def get_tokenizer_and_pad_id(config): |
| 38 | + """Builds tokenizer and extracts pad_id safely.""" |
| 39 | + tokenizer_model = tokenizer.build_tokenizer( |
| 40 | + config.tokenizer_path, |
| 41 | + config.tokenizer_type, |
| 42 | + config.add_bos, |
| 43 | + config.add_eos, |
| 44 | + config.hf_access_token, |
| 45 | + ) |
| 46 | + if tokenizer_model.pad_id is not None: |
| 47 | + pad_id = tokenizer_model.pad_id |
| 48 | + elif tokenizer_model.unk_id is not None: |
| 49 | + pad_id = tokenizer_model.unk_id |
| 50 | + else: |
| 51 | + pad_id = 0 |
| 52 | + return tokenizer_model, pad_id |
| 53 | + |
| 54 | + |
| 55 | +def validate_and_configure_sft_columns(data_columns, tokenizer_model, chat_template=None): |
| 56 | + """Validates SFT data columns and configures the tokenizer chat template.""" |
| 57 | + if chat_template and hasattr(tokenizer_model, "chat_template"): |
| 58 | + tokenizer_model.chat_template = chat_template |
| 59 | + |
| 60 | + supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]] |
| 61 | + assert any( |
| 62 | + set(data_columns) == set(supported) for supported in supported_columns |
| 63 | + ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}" |
| 64 | + |
| 65 | + |
| 66 | +def get_local_batch_size(config): |
| 67 | + """Computes local batch size based on process count and expansion factor.""" |
| 68 | + batch_size = config.global_batch_size_to_load // jax.process_count() |
| 69 | + if config.expansion_factor_real_data > 1: |
| 70 | + # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1. |
| 71 | + # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint. |
| 72 | + # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py. |
| 73 | + batch_size = int(batch_size // config.expansion_factor_real_data) |
| 74 | + return batch_size |
| 75 | + |
| 76 | + |
| 77 | +def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model): |
| 78 | + """Packs or pads the dataset according to config and batches it.""" |
| 79 | + if config.packing: |
| 80 | + length_struct = {col: config.max_target_length for col in data_columns} |
| 81 | + max_segments = config.max_segments_per_seq |
| 82 | + if max_segments is not None and max_segments <= 0: |
| 83 | + max_segments = None |
| 84 | + if config.grain_packing_type == "first_fit": |
| 85 | + dataset = grain.experimental.FirstFitPackIterDataset( |
| 86 | + dataset, |
| 87 | + length_struct=length_struct, |
| 88 | + num_packing_bins=batch_size, |
| 89 | + max_sequences_per_bin=max_segments, |
| 90 | + ) |
| 91 | + elif config.grain_packing_type == "best_fit": |
| 92 | + dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size) |
| 93 | + elif config.grain_packing_type == "concat_then_split": |
| 94 | + if config.add_bos and hasattr(tokenizer_model, "bos_id"): |
| 95 | + dataset = grain.experimental.ConcatThenSplitIterDataset( |
| 96 | + dataset, |
| 97 | + length_struct=length_struct, |
| 98 | + bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS, |
| 99 | + bos_token_id=tokenizer_model.bos_id, |
| 100 | + ) |
| 101 | + else: |
| 102 | + dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct) |
| 103 | + else: |
| 104 | + raise ValueError(f"Unknown packing type: {config.packing}") |
| 105 | + |
| 106 | + rekey_dict = { |
| 107 | + "targets_segmentation": "targets_segment_ids", |
| 108 | + "inputs_segmentation": "inputs_segment_ids", |
| 109 | + "targets_position": "targets_positions", |
| 110 | + "inputs_position": "inputs_positions", |
| 111 | + } |
| 112 | + dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict)) |
| 113 | + else: |
| 114 | + dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id)) |
| 115 | + |
| 116 | + batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id) |
| 117 | + dataset = dataset.batch(batch_size, batch_fn=batch_fn) |
| 118 | + return dataset |
| 119 | + |
| 120 | + |
| 121 | +def shift_dataset(dataset, pad_id): |
| 122 | + """Shift tokens to create inputs and targets for standard next-token prediction.""" |
| 123 | + return dataset.map( |
| 124 | + input_pipeline_utils.ShiftData( |
| 125 | + ignored_ids=[pad_id], |
| 126 | + axis=1, |
| 127 | + ) |
| 128 | + ) |
| 129 | + |
| 130 | + |
| 131 | +def apply_multiprocessing_and_prefetch(dataset, config, grain_worker_count, grain_per_worker_buffer_size): |
| 132 | + """Applies multiprocessing and prefetching configurations to the dataset.""" |
| 133 | + multiprocessing_options = ( |
| 134 | + pick_performance_config( |
| 135 | + ds=dataset, |
| 136 | + ram_budget_mb=config.grain_ram_budget_mb, |
| 137 | + max_workers=None, |
| 138 | + max_buffer_size=None, |
| 139 | + ).multiprocessing_options |
| 140 | + if grain_worker_count == -1 |
| 141 | + else grain.MultiprocessingOptions( |
| 142 | + num_workers=grain_worker_count, |
| 143 | + per_worker_buffer_size=grain_per_worker_buffer_size, |
| 144 | + ) |
| 145 | + ) |
| 146 | + return dataset.mp_prefetch(multiprocessing_options) |
0 commit comments