# Helpers extracted from the Grain data-processing pipelines. They operate on
# grain datasets plus the MaxText config and are intended as pure refactors of
# logic that previously lived inline in grain_data_processing.py.


def parse_and_keep_features(dataset, config, data_columns, tokenize):
  """Parse arrayrecord/tfrecord features, or keep named columns for other formats.

  Args:
    dataset: grain dataset to transform.
    config: MaxText config; only `grain_file_type` is read here.
    data_columns: names of the data columns to parse or retain.
    tokenize: forwarded to the parse/normalize transforms (whether raw text
      will be tokenized downstream).

  Returns:
    The dataset with the appropriate feature transforms mapped over it.
  """
  if config.grain_file_type in ("arrayrecord", "tfrecord"):
    dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
  else:
    dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
  return dataset


def get_tokenizer_and_pad_id(config):
  """Builds the tokenizer and derives the id used for padding.

  Preference order: the tokenizer's pad_id, then its unk_id, then -1.

  Returns:
    A (tokenizer_model, pad_id) tuple.
  """
  tokenizer_model = tokenizer.build_tokenizer(
      config.tokenizer_path,
      config.tokenizer_type,
      config.add_bos,
      config.add_eos,
      config.hf_access_token,
  )
  if tokenizer_model.pad_id is not None:
    pad_id = tokenizer_model.pad_id
  elif tokenizer_model.unk_id is not None:
    pad_id = tokenizer_model.unk_id
  else:
    # Fix: fall back to -1, not 0. The inline code this helper replaced used
    # -1, and 0 is a legitimate token id in most vocabularies — padding with 0
    # would silently corrupt sequences that contain that token.
    pad_id = -1
  return tokenizer_model, pad_id


def validate_and_configure_sft_columns(data_columns, tokenizer_model, chat_template=None):
  """Validates SFT data columns and optionally installs a chat template.

  Args:
    data_columns: dataset column names; must match one supported SFT layout.
    tokenizer_model: tokenizer whose `chat_template` may be overridden.
    chat_template: optional template string; applied only when the tokenizer
      exposes a `chat_template` attribute.

  Raises:
    AssertionError: if `data_columns` does not match a supported column set.
  """
  # Only override the template when the tokenizer actually supports one.
  if chat_template and hasattr(tokenizer_model, "chat_template"):
    tokenizer_model.chat_template = chat_template

  supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]]
  assert any(
      set(data_columns) == set(supported) for supported in supported_columns
  ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"


def get_local_batch_size(config):
  """Computes the per-process (local) batch size.

  Returns:
    The number of examples this host should load per step.
  """
  batch_size = config.global_batch_size_to_load // jax.process_count()
  if config.expansion_factor_real_data > 1:
    # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1.
    # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint.
    # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py.
    batch_size = int(batch_size // config.expansion_factor_real_data)
  return batch_size
def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model):
  """Packs (when config.packing) or pads sequences, then batches the dataset.

  Args:
    dataset: grain dataset of tokenized examples.
    config: MaxText config (packing flags, max_target_length, ...).
    batch_size: local (per-process) batch size; also used as the number of
      packing bins for the bin-packing strategies.
    pad_id: token id used to pad short sequences and partial batches.
    data_columns: columns packed/padded to config.max_target_length.
    tokenizer_model: consulted only for `bos_id` under concat_then_split.

  Returns:
    The packed-or-padded, batched dataset.

  Raises:
    ValueError: if config.grain_packing_type is not a recognized strategy.
  """
  if config.packing:
    length_struct = {col: config.max_target_length for col in data_columns}
    max_segments = config.max_segments_per_seq
    # Non-positive values mean "no limit on segments per bin".
    if max_segments is not None and max_segments <= 0:
      max_segments = None
    if config.grain_packing_type == "first_fit":
      dataset = grain.experimental.FirstFitPackIterDataset(
          dataset,
          length_struct=length_struct,
          num_packing_bins=batch_size,
          max_sequences_per_bin=max_segments,
      )
    elif config.grain_packing_type == "best_fit":
      dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
    elif config.grain_packing_type == "concat_then_split":
      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
        dataset = grain.experimental.ConcatThenSplitIterDataset(
            dataset,
            length_struct=length_struct,
            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
            bos_token_id=tokenizer_model.bos_id,
        )
      else:
        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
    else:
      # Fix: report the value that was actually dispatched on. The original
      # message interpolated config.packing (a boolean), which made the error
      # text useless for debugging a bad grain_packing_type.
      raise ValueError(f"Unknown packing type: {config.grain_packing_type}")

    # NOTE(review): rename between grain's packed-output keys and the
    # *_segmentation/*_position keys used downstream — confirm the mapping
    # direction against input_pipeline_utils.Rekey.
    rekey_dict = {
        "targets_segmentation": "targets_segment_ids",
        "inputs_segmentation": "inputs_segment_ids",
        "targets_position": "targets_positions",
        "inputs_position": "inputs_positions",
    }
    dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
  else:
    dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))

  # Pad the final partial batch with pad_id rather than dropping it.
  batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
  dataset = dataset.batch(batch_size, batch_fn=batch_fn)
  return dataset


def shift_dataset(dataset, pad_id):
  """Shift tokens to create inputs and targets for standard next-token prediction."""
  return dataset.map(
      input_pipeline_utils.ShiftData(
          ignored_ids=[pad_id],
          axis=1,
      )
  )


def apply_multiprocessing_and_prefetch(dataset, config, grain_worker_count, grain_per_worker_buffer_size):
  """Applies multiprocessing and prefetching configurations to the dataset.

  A grain_worker_count of -1 lets grain auto-tune worker/buffer settings
  within config.grain_ram_budget_mb; any other value is used verbatim.
  """
  multiprocessing_options = (
      pick_performance_config(
          ds=dataset,
          ram_budget_mb=config.grain_ram_budget_mb,
          max_workers=None,
          max_buffer_size=None,
      ).multiprocessing_options
      if grain_worker_count == -1
      else grain.MultiprocessingOptions(
          num_workers=grain_worker_count,
          per_worker_buffer_size=grain_per_worker_buffer_size,
      )
  )
  return dataset.mp_prefetch(multiprocessing_options)
data_processing_utils.parse_and_keep_features(dataset, config, data_columns, tokenize) assert len(data_columns) == 1 text_column = data_columns[0] - tokenizer_model = tokenizer.build_tokenizer( - config.tokenizer_path, - config.tokenizer_type, - config.add_bos, - config.add_eos, - config.hf_access_token, - ) - if tokenizer_model.pad_id is not None: - pad_id = tokenizer_model.pad_id - elif tokenizer_model.unk_id is not None: - pad_id = tokenizer_model.unk_id - else: - pad_id = -1 + tokenizer_model, pad_id = data_processing_utils.get_tokenizer_and_pad_id(config) if tokenize: if config.use_truncation: @@ -248,74 +231,13 @@ def pretrain_preprocessing_pipeline( rekey_dict = {col: text_column for col in data_columns} dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict)) - # Pack and Batch examples. - batch_size = config.global_batch_size_to_load // jax.process_count() - if config.expansion_factor_real_data > 1: - # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1. - # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint. - # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py. 
- batch_size = int(batch_size // config.expansion_factor_real_data) - - if config.packing: - length_struct = {col: config.max_target_length for col in data_columns} - max_segments = config.max_segments_per_seq - if max_segments is not None and max_segments <= 0: - max_segments = None - if config.grain_packing_type == "first_fit": - dataset = grain.experimental.FirstFitPackIterDataset( - dataset, - length_struct=length_struct, - num_packing_bins=batch_size, - max_sequences_per_bin=max_segments, - ) - elif config.grain_packing_type == "best_fit": - dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size) - elif config.grain_packing_type == "concat_then_split": - if config.add_bos and hasattr(tokenizer_model, "bos_id"): - dataset = grain.experimental.ConcatThenSplitIterDataset( - dataset, - length_struct=length_struct, - bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS, - bos_token_id=tokenizer_model.bos_id, - ) - else: - dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct) - else: - raise ValueError(f"Unknown packing type: {config.packing}") - - rekey_dict = { - "targets_segmentation": "targets_segment_ids", - "inputs_segmentation": "inputs_segment_ids", - "targets_position": "targets_positions", - "inputs_position": "inputs_positions", - } - dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict)) - else: - dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id)) - batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id) - dataset = dataset.batch(batch_size, batch_fn=batch_fn) + batch_size = data_processing_utils.get_local_batch_size(config) + dataset = data_processing_utils.format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model) - # Shift inputs for teacher-forced training - dataset = dataset.map( - input_pipeline_utils.ShiftData( - 
ignored_ids=[pad_id], - axis=1, - ) - ) - multiprocessing_options = ( - pick_performance_config( - ds=dataset, - ram_budget_mb=config.grain_ram_budget_mb, - max_workers=None, - max_buffer_size=None, - ).multiprocessing_options - if grain_worker_count == -1 - else grain.MultiprocessingOptions( - num_workers=grain_worker_count, - per_worker_buffer_size=grain_per_worker_buffer_size, - ) + dataset = data_processing_utils.shift_dataset(dataset, pad_id) + dataset = data_processing_utils.apply_multiprocessing_and_prefetch( + dataset, config, grain_worker_count, grain_per_worker_buffer_size ) - dataset = dataset.mp_prefetch(multiprocessing_options) return dataset @@ -328,22 +250,8 @@ def dpo_preprocessing_pipeline( grain_per_worker_buffer_size, ): """Use grain to pre-process the dataset and return iterators for dpo fine-tuning""" - if config.grain_file_type in ("arrayrecord", "tfrecord"): - dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize)) - dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) - tokenizer_model = tokenizer.build_tokenizer( - config.tokenizer_path, - config.tokenizer_type, - config.add_bos, - config.add_eos, - config.hf_access_token, - ) - if tokenizer_model.pad_id is not None: - pad_id = tokenizer_model.pad_id - elif tokenizer_model.unk_id is not None: - pad_id = tokenizer_model.unk_id - else: - pad_id = -1 + dataset = data_processing_utils.parse_and_keep_features(dataset, config, data_columns, tokenize) + tokenizer_model, pad_id = data_processing_utils.get_tokenizer_and_pad_id(config) if tokenize: dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model)) @@ -352,20 +260,96 @@ def dpo_preprocessing_pipeline( batch_size = config.global_batch_size_to_load // jax.process_count() batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id) dataset = dataset.batch(batch_size, batch_fn=batch_fn) - 
multiprocessing_options = ( - pick_performance_config( - ds=dataset, - ram_budget_mb=config.grain_ram_budget_mb, - max_workers=None, - max_buffer_size=None, - ).multiprocessing_options - if grain_worker_count == -1 - else grain.MultiprocessingOptions( - num_workers=grain_worker_count, - per_worker_buffer_size=grain_per_worker_buffer_size, + dataset = data_processing_utils.apply_multiprocessing_and_prefetch( + dataset, config, grain_worker_count, grain_per_worker_buffer_size + ) + return dataset + + +def _format_chat_template_grain(element, data_columns, tokenizer_model): + """Grain-compatible mapping function to format raw columns into conversational messages.""" + # Convert raw columns to conversational messages + if "messages" in data_columns: + messages = element["messages"] + elif set(data_columns) == {"prompt", "completion"}: + messages = [{"role": "user", "content": element["prompt"]}, {"role": "assistant", "content": element["completion"]}] + elif set(data_columns) == {"question", "answer"}: + messages = [{"role": "user", "content": element["question"]}, {"role": "assistant", "content": element["answer"]}] + else: + # Fallback if it's already a single string + messages = element[data_columns[0]] + + assert all( + hasattr(m, "__contains__") and "role" in m and "content" in m for m in messages + ), f"SFT requires a conversational format. 
Expected dicts with 'role' and 'content', but got: {messages}" + + # Assign the standardized messages back to the primary column + element[data_columns[0]] = messages + + return input_pipeline_utils.apply_chat_template( + element, tokenizer_model=tokenizer_model, data_column_name=data_columns[0] + ) + + +def _tokenize_sft_chunks(element, text_column_name, tokenizer_model): + """Tokenize each chunk individually without truncating.""" + text_chunks = element[text_column_name] + element[text_column_name] = [tokenizer_model.encode(chunk) for chunk in text_chunks] + return element + + +def sft_preprocessing_pipeline( + dataset, + config, + data_columns, + tokenize, + grain_worker_count, + grain_per_worker_buffer_size, +): + """Use grain pipeline to pre-process the dataset and return iterators for sft fine-tuning""" + dataset = data_processing_utils.parse_and_keep_features(dataset, config, data_columns, tokenize) + + tokenizer_model, pad_id = data_processing_utils.get_tokenizer_and_pad_id(config) + base_tokenizer_model = tokenizer_model + + tokenizer_model = getattr(tokenizer_model, "tokenizer", tokenizer_model) + + data_processing_utils.validate_and_configure_sft_columns( + data_columns, tokenizer_model, getattr(config, "chat_template", None) + ) + + dataset = dataset.map( + functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model) + ) + + if tokenize: + dataset = dataset.map( + functools.partial( + _tokenize_sft_chunks, + text_column_name=data_columns[0], + tokenizer_model=tokenizer_model, + ) + ) + + dataset = dataset.map( + input_pipeline_utils.SFTPromptMasking( + text_column_name=data_columns[0], + completion_only=config.sft_train_on_completion_only, + max_target_length=config.max_target_length, + unk_id=pad_id, ) ) - dataset = dataset.mp_prefetch(multiprocessing_options) + data_columns = ("inputs", "targets") + + batch_size = data_processing_utils.get_local_batch_size(config) + dataset = 
data_processing_utils.format_and_batch( + dataset, config, batch_size, pad_id, data_columns, base_tokenizer_model + ) + + dataset = data_processing_utils.shift_dataset(dataset, pad_id) + dataset = data_processing_utils.apply_multiprocessing_and_prefetch( + dataset, config, grain_worker_count, grain_per_worker_buffer_size + ) return dataset @@ -403,6 +387,15 @@ def make_grain_train_iterator( grain_worker_count=config.grain_worker_count, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, ) + elif config.use_sft: + train_dataloader = sft_preprocessing_pipeline( + train_ds, + config, + data_columns=config.train_data_columns, + tokenize=config.tokenize_train_data, + grain_worker_count=config.grain_worker_count, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, + ) else: train_dataloader = pretrain_preprocessing_pipeline( train_ds, @@ -434,7 +427,16 @@ def make_grain_train_iterator( ) if config.use_dpo: preprocessing_fn = functools.partial( - pretrain_preprocessing_pipeline, + dpo_preprocessing_pipeline, + config=config, + data_columns=config.train_data_columns, + tokenize=config.tokenize_train_data, + grain_worker_count=config.grain_worker_count, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, + ) + elif config.use_sft: + preprocessing_fn = functools.partial( + sft_preprocessing_pipeline, config=config, data_columns=config.train_data_columns, tokenize=config.tokenize_train_data, @@ -502,6 +504,15 @@ def make_grain_eval_iterator( grain_worker_count=config.grain_worker_count_eval, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, ) + elif config.use_sft: + eval_dataloader = sft_preprocessing_pipeline( + eval_ds, + config, + data_columns=config.eval_data_columns, + tokenize=config.tokenize_eval_data, + grain_worker_count=config.grain_worker_count_eval, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, + ) else: eval_dataloader = pretrain_preprocessing_pipeline( eval_ds, @@ 
-537,6 +548,15 @@ def make_grain_eval_iterator( grain_worker_count=config.grain_worker_count_eval, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, ) + elif config.use_sft: + preprocessing_fn = functools.partial( + sft_preprocessing_pipeline, + config=config, + data_columns=config.eval_data_columns, + tokenize=config.tokenize_eval_data, + grain_worker_count=config.grain_worker_count_eval, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, + ) else: preprocessing_fn = functools.partial( pretrain_preprocessing_pipeline, diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index 649ff29876..d1be2c4890 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -26,6 +26,7 @@ import numpy as np +from maxtext.input_pipeline import data_processing_utils from maxtext.input_pipeline import input_pipeline_utils from maxtext.input_pipeline import instruction_data_processing from maxtext.input_pipeline import multihost_dataloading @@ -248,13 +249,7 @@ def preprocessing_pipeline( dataset = dataset.select_columns(data_column_names) if use_sft: - if chat_template: - tokenizer.chat_template = chat_template - - supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]] - assert any( - set(data_column_names) == set(supported) for supported in supported_columns - ), f"Dataset column names mismatch. 
class GrainSFTParquetProcessingTest(unittest.TestCase):
  """Tests the SFT pipeline end-to-end using the real ultrachat_200k parquet dataset."""

  def setUp(self):
    # NOTE(review): this test reads parquet shards from GCS and downloads a
    # Hugging Face tokenizer, so it needs network access and credentials.
    super().setUp()

    grain_train_file = "gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet"
    base_output_directory = "gs://max-experiments/"
    config_file = get_test_config_path()

    self.config = pyconfig.initialize(
        [sys.argv[0], config_file],
        per_device_batch_size=1,
        run_name="test",
        mesh_axes=["data"],
        logical_axis_rules=[["batch", "data"]],
        data_sharding=["data"],
        base_output_directory=base_output_directory,
        dataset_type="grain",
        grain_file_type="parquet",
        grain_train_files=grain_train_file,
        use_sft=True,  # routes make_grain_train_iterator to the SFT pipeline
        sft_train_on_completion_only=True,
        train_data_columns=["messages"],
        tokenizer_type="huggingface",
        tokenizer_path="HuggingFaceH4/zephyr-7b-beta",  # ungated tokenizer; no access token required
        max_target_length=128,
        packing=True,
        grain_worker_count=1,
        grain_per_worker_buffer_size=1,
        enable_checkpointing=False,
    )
    self.mesh_shape_1d = (len(jax.devices()),)
    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
        self.config.data_sharding,
        self.config.global_batch_size_to_load,
        self.config.global_batch_size_to_train_on,
        self.config.max_target_length,
        self.mesh,
    )
    self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)

  def test_train_ds(self):
    # Every tensor in a packed SFT batch should be (num_devices, max_target_length).
    expected_shape = [jax.device_count(), self.config.max_target_length]
    batch = next(self.train_iter)

    # Assert all the required packing and target tensors were generated
    self.assertEqual(
        {k: list(v.shape) for k, v in batch.items()},
        {
            "inputs": expected_shape,
            "inputs_position": expected_shape,
            "inputs_segmentation": expected_shape,
            "targets": expected_shape,
            "targets_position": expected_shape,
            "targets_segmentation": expected_shape,
        },
    )

    # check to see that if prompts are masked, targets will differ from inputs
    has_masked_tokens = np.any(batch["inputs"] != batch["targets"])
    self.assertTrue(bool(has_masked_tokens), "Targets array should differ from inputs array due to prompt masking.")