Skip to content

Commit 146edf3

Browse files
aireenmeicopybara-github
authored andcommitted
PR #3430: Support TFrecord in Grain pipeline
Imported from GitHub PR AI-Hypercomputer/maxtext#3430 # Description * Support grain_file_type=tfrecord in Grain pipeline * Add unit test `GrainTFRecordProcessingTest` * Refactor Grain unit test for cleaner code * Add docs, the formatting changes in `data_input_pipeline.md` is made by pre-commit mdformat # Tests * unit test * test run with command: ``` python3 -m MaxText.train maxtext/configs/base.yml \ run_name=${RUN_NAME} \ base_output_directory=${GCS_BUCKET} \ dataset_type=grain \ grain_file_type=tfrecord \ grain_train_files=gs://maxtext-dataset/c4/en/3.0.1/c4-train.tfrecord-* \ grain_worker_count=1 \ tokenizer_type=huggingface \ tokenizer_path=google-t5/t5-large \ add_bos=false \ num_epoch=1 \ steps=10 \ enable_checkpointing=false ``` log: ``` INFO:absl:completed step: 0, seconds: 32.063, TFLOP/s/device: 5.085, Tokens/s/device: 766.485, total_weights: 97543, loss: 10.871 INFO:absl:To see full metrics 'tensorboard --logdir=gs://aireenmei-multipod/test/grain/grain-2026-03-16-20-30-35/tensorboard/' INFO:absl:completed step: 1, seconds: 0.432, TFLOP/s/device: 377.451, Tokens/s/device: 56893.498, total_weights: 96426, loss: 10.866 INFO:absl:completed step: 2, seconds: 1.056, TFLOP/s/device: 154.471, Tokens/s/device: 23283.531, total_weights: 96824, loss: 9.716 INFO:absl:completed step: 3, seconds: 0.927, TFLOP/s/device: 175.932, Tokens/s/device: 26518.364, total_weights: 94917, loss: 9.183 INFO:absl:completed step: 4, seconds: 1.202, TFLOP/s/device: 135.602, Tokens/s/device: 20439.394, total_weights: 96144, loss: 8.930 INFO:absl:completed step: 5, seconds: 1.202, TFLOP/s/device: 135.628, Tokens/s/device: 20443.389, total_weights: 96021, loss: 8.782 INFO:absl:completed step: 6, seconds: 1.203, TFLOP/s/device: 135.565, Tokens/s/device: 20433.888, total_weights: 96877, loss: 8.595 INFO:absl:completed step: 7, seconds: 1.203, TFLOP/s/device: 135.544, Tokens/s/device: 20430.609, total_weights: 96736, loss: 8.476 INFO:absl:completed step: 8, seconds: 1.202, TFLOP/s/device: 135.641, Tokens/s/device: 20445.226, total_weights: 95375, loss: 8.465 INFO:absl:completed step: 9, seconds: 1.203, TFLOP/s/device: 135.553, Tokens/s/device: 20432.002, total_weights: 96435, loss: 8.375 ``` In comparison, using the same TFRecord files in tfds pipeline is very slow at the beginning ``` INFO:absl:completed step: 0, seconds: 27.974, TFLOP/s/device: 5.828, Tokens/s/device: 878.532, total_weights: 81974, loss: 10.873 INFO:absl:To see full metrics 'tensorboard --logdir=gs://aireenmei-multipod/test/tfds/tfds-2026-03-17-23-33-04/tensorboard/' INFO:absl:completed step: 1, seconds: 6.055, TFLOP/s/device: 26.929, Tokens/s/device: 4059.096, total_weights: 79833, loss: 10.882 INFO:absl:completed step: 2, seconds: 7.814, TFLOP/s/device: 20.867, Tokens/s/device: 3145.257, total_weights: 82948, loss: 9.759 INFO:absl:completed step: 3, seconds: 5.320, TFLOP/s/device: 30.645, Tokens/s/device: 4619.175, total_weights: 75398, loss: 9.151 INFO:absl:completed step: 4, seconds: 5.777, TFLOP/s/device: 28.223, Tokens/s/device: 4254.122, total_weights: 81795, loss: 9.026 INFO:absl:completed step: 5, seconds: 6.034, TFLOP/s/device: 27.020, Tokens/s/device: 4072.730, total_weights: 83248, loss: 8.848 INFO:absl:completed step: 6, seconds: 2.280, TFLOP/s/device: 71.522, Tokens/s/device: 10780.621, total_weights: 79633, loss: 8.697 INFO:absl:completed step: 7, seconds: 0.010, TFLOP/s/device: 16720.905, Tokens/s/device: 2520356.886, total_weights: 79027, loss: 8.530 INFO:absl:completed step: 8, seconds: 1.202, TFLOP/s/device: 135.698, Tokens/s/device: 20453.887, total_weights: 79079, loss: 8.477 INFO:absl:completed step: 9, seconds: 1.202, TFLOP/s/device: 135.611, Tokens/s/device: 20440.771, total_weights: 78018, loss: 8.433 ``` # Checklist Before submitting this PR, please make sure (put X in square brackets): - [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label. - [x] I have necessary comments in my code, particularly in hard-to-understand areas. - [x] I have run end-to-end tests tests and provided workload links above if applicable. - [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files). Copybara import of the project: -- 9e365436e30735554772a91ac1e3c2311544b757 by aireenmei <aireenmei@gmail.com>: Support TFrecord in Grain pipeline Merging this change closes #3430 PiperOrigin-RevId: 886960115
1 parent 9cd3ac9 commit 146edf3

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

  • grain/_src/python/dataset/sources

grain/_src/python/dataset/sources/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ py_library(
2929
name = "tfrecord_dataset",
3030
srcs = ["tfrecord_dataset.py"],
3131
srcs_version = "PY3",
32+
visibility = [
33+
"//third_party/py/grain:internal",
34+
"//third_party/py/maxtext:__pkg__",
35+
],
3236
deps = ["//grain/_src/python/dataset"],
3337
)
3438

0 commit comments

Comments
 (0)