Added SFT Pre-Processing for Grain Input Pipeline#3437

Merged
copybara-service[bot] merged 10 commits into main from ajkv/sft-grain-implementation
Apr 1, 2026

@ajkv-google ajkv-google commented Mar 18, 2026

Description

This PR introduces SFT support to the Grain input pipeline by adding a separate sft_preprocessing_pipeline function. Rather than cluttering the existing pretrain code, it uses simple conditionals inside the train and eval iterators to route to this new SFT logic. I followed the existing Hugging Face SFT implementation and adapted its logic to be compatible with Grain's element-wise datasets.
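The routing described above can be illustrated with a minimal, self-contained sketch. It is not the code from this PR: only the name `sft_preprocessing_pipeline` comes from the description, while the signatures, the `use_sft` config key, the `messages` example layout, and the pretrain stand-in are assumptions made for illustration. Plain Python generators stand in for Grain's element-wise datasets.

```python
from typing import Any, Dict, Iterable, Iterator

Example = Dict[str, Any]


def pretrain_preprocessing_pipeline(dataset: Iterable[Example], config: Dict[str, Any]) -> Iterator[Example]:
  """Hypothetical stand-in for the existing pretraining path."""
  return ({"text": ex["text"].strip()} for ex in dataset)


def sft_preprocessing_pipeline(dataset: Iterable[Example], config: Dict[str, Any]) -> Iterator[Example]:
  """Hypothetical SFT path: each element is formatted independently of the others."""

  def format_example(ex: Example) -> Example:
    # Assumed layout: a list of {"role", "content"} chat messages.
    messages = ex["messages"]
    prompt = next(m["content"] for m in messages if m["role"] == "user")
    completion = next(m["content"] for m in messages if m["role"] == "assistant")
    return {"prompt": prompt, "completion": completion}

  return (format_example(ex) for ex in dataset)


def make_train_iterator(dataset: Iterable[Example], config: Dict[str, Any]) -> Iterator[Example]:
  """A single conditional routes to the SFT logic; the pretrain path is unchanged."""
  if config.get("use_sft", False):  # assumed flag name
    return sft_preprocessing_pipeline(dataset, config)
  return pretrain_preprocessing_pipeline(dataset, config)


conversations = [
    {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]}
]
sft_batch = list(make_train_iterator(conversations, {"use_sft": True}))
pretrain_batch = list(make_train_iterator([{"text": " a "}], {}))
```

Keeping the SFT steps in their own function means the pretraining branch is left as it was, and the eval iterator can reuse the same conditional.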

Tests

I added an end-to-end unit test to verify that the Grain SFT pipeline formats the data and produces the expected output. I ran this command to execute the unit test:

  • pytest tests/unit/grain_data_processing_test.py::GrainSFTParquetProcessingTest -v

This is the output of the test: Test Passed Output

I also ran the MaxText training pipeline with SFT enabled on a Grain dataset using this command:

  • python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=test_grain_sft dataset_type=grain grain_file_type=parquet grain_train_files=gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet steps=10 tokenizer_type=huggingface tokenizer_path=HuggingFaceH4/zephyr-7b-beta

Verified that the sft processing changes worked and trained successfully: Logs

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 61.68224% with 41 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...rc/maxtext/input_pipeline/grain_data_processing.py | 51.02% | 23 Missing and 1 partial ⚠️ |
| ...rc/maxtext/input_pipeline/data_processing_utils.py | 69.64% | 11 Missing and 6 partials ⚠️ |

@vlad-karp vlad-karp left a comment

It would also be great to test this not only with the general MaxText SFT pipeline but with the distillation SFT pipeline as well.

Comment thread src/maxtext/input_pipeline/grain_data_processing.py Outdated
Comment thread src/maxtext/input_pipeline/grain_data_processing.py Outdated
Comment thread src/maxtext/input_pipeline/grain_data_processing.py Outdated
Comment thread src/maxtext/input_pipeline/grain_data_processing.py
Comment thread src/maxtext/input_pipeline/grain_data_processing.py
Comment thread src/maxtext/input_pipeline/grain_data_processing.py Outdated
Comment thread src/maxtext/input_pipeline/grain_data_processing.py Outdated
Comment thread src/maxtext/input_pipeline/grain_data_processing.py
Comment thread src/maxtext/input_pipeline/grain_data_processing.py Outdated
Comment thread src/maxtext/input_pipeline/grain_data_processing.py
Comment thread src/maxtext/input_pipeline/data_processing_utils.py Outdated
Comment thread src/maxtext/input_pipeline/data_processing_utils.py Outdated
Comment thread src/maxtext/input_pipeline/grain_data_processing.py

@JamesDeng42 JamesDeng42 left a comment

The change is clean. I left one comment; otherwise LGTM.

@ajkv-google ajkv-google force-pushed the ajkv/sft-grain-implementation branch from 4f4c45c to c7bd12f Compare March 31, 2026 20:36
@copybara-service copybara-service Bot merged commit f2216e2 into main Apr 1, 2026
39 of 42 checks passed
@copybara-service copybara-service Bot deleted the ajkv/sft-grain-implementation branch April 1, 2026 00:14

4 participants