Grain is a library for reading data for training and evaluating JAX models. It’s designed to be:
- Powerful: Users can bring arbitrary Python transformations.
- Flexible: Users can readily override Grain components for their needs.
- Deterministic: Multiple runs of the same pipeline will produce the same outputs.
- Resilient to preemptions: With minimal-size checkpoints, users can resume the data loader from the point at which it was preempted and produce the same output as if it had never been preempted.
- Performant: Achieved with multiprocessing and shared memory. Tested on multiple data modalities.
- With minimal dependencies: Does not depend on ML frameworks (e.g., TensorFlow).
Determinism in a data input pipeline means that the same input data always results in the same sequence of batches at each step. This is typically achieved by setting a fixed shuffle seed during pipeline initialization. In an ideal scenario, where training runs uninterrupted, this determinism is straightforward (deterministic without preemption). However, real-world distributed training environments often face preemptions due to maintenance, hardware failures, or resource constraints. When a preempted training run resumes, the data input pipeline is re-initialized. If the same shuffle seed is used, the pipeline restarts from the beginning, potentially re-training the model on initial data. Conversely, a new seed produces a different batch sequence, making it difficult to track which data has been seen and how often each example is used for training. This lack of control can impact model performance and reproducibility.
Grain ensures determinism in data input pipelines by saving the pipeline's state, including dataset metadata and processed data indices, within a small JSON file in checkpoints (see the checkpointing section in Grain DataLoader tutorial or grain.checkpoint module). When a training run is resumed with the same dataset and shuffle seed, Grain restores the pipeline's exact state from the checkpoint. This enables fully deterministic, reproducible training that is resilient to disruptions.
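To make the saved state concrete, here is a minimal sketch of checkpointing and restoring an iterator position with Grain's Python API. The file pattern, seed, and worker count are placeholders, and the exact `get_state()`/`set_state()` surface should be verified against your Grain version (the `grain.checkpoint` module provides Orbax-compatible handlers for the same purpose):

```python
import glob
import grain.python as grain

# Placeholder dataset path; any random-access source works the same way.
paths = sorted(glob.glob("/tmp/gcsfuse/my_dataset.array_record*"))
source = grain.ArrayRecordDataSource(paths)

sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=42,  # fixed seed: identical batch order on every run
)
loader = grain.DataLoader(data_source=source, sampler=sampler, worker_count=2)

it = iter(loader)
_ = next(it)            # consume one element
state = it.get_state()  # small serializable blob; save it alongside the model checkpoint

# After a preemption: rebuild the same pipeline, then restore the exact position.
it2 = iter(loader)
it2.set_state(state)    # next(it2) yields the element after the one consumed above
```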
Determinism is particularly valuable in scenarios such as:
- Models sensitive to repetition: When models are sensitive to the frequency with which they encounter specific examples, precise control over the order and repetition of data during training is essential. All LLMs belong to this category.
- Convergence comparison: In sensitive convergence experiments like testing quantization techniques, maintaining identical data batches between runs (e.g., quantized vs. unquantized) is essential for comparison. Determinism ensures consistency even when the runs are long and undergo saving/resuming at different steps.
- Debug training anomalies: When troubleshooting training spikes or anomalies, the ability to replay the exact data sequence helps distinguish between bad data batches and underlying hardware or software issues.
- Global shuffle: This feature is available only when using Grain with the ArrayRecord (random-access) format. Indices are shuffled globally at the beginning of each epoch, and elements are then read according to that random order. This shuffle method effectively prevents local overfitting, leading to better training results.
- Hierarchical shuffle: For the sequential-access Parquet format, shuffling is performed in stages: file shuffling, interleaving reads across files, and window shuffling with a fixed-size buffer (a sketch of window shuffling follows below).
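The window-shuffle stage can be understood from a short, self-contained sketch: keep a fixed-size buffer and emit a uniformly random element from it as each new element arrives. This is a conceptual illustration of the technique, not Grain's implementation:

```python
import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def window_shuffle(stream: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Approximate shuffle for sequentially read data using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer: list[T] = []
    for item in stream:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            # Swap a random element to the end and emit it.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    rng.shuffle(buffer)  # drain the remainder at end of stream
    yield from buffer

# Each element lands within roughly buffer_size positions of where it started.
print(list(window_shuffle(range(10), buffer_size=4, seed=0)))
```

Unlike a global shuffle, this only mixes elements locally, which is why file shuffling and interleaving are applied first.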
- Grain currently supports three data formats: ArrayRecord (random access), Parquet (partial random access through row groups), and TFRecord (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see Apache Beam Integration for ArrayRecord. Additionally, other random-access data sources can be supported via a custom data source class (see the sketch after this list).
- Community resource: The MaxText community has created an ArrayRecord Documentation. Note: we appreciate the contribution from the community, but it has not yet been verified by the MaxText or ArrayRecord developers.
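To illustrate the custom data source path, the hypothetical class below wraps a JSONL file as a random-access source. Grain treats any object exposing `__len__` and `__getitem__` as a random-access data source, so an instance can be passed to `grain.DataLoader` in place of, say, an `ArrayRecordDataSource`:

```python
import json

class JsonlDataSource:
    """Hypothetical random-access data source over a JSONL file."""

    def __init__(self, path: str):
        with open(path, "rb") as f:
            data = f.read()
        # Index the byte offset of each line once, so lookups are cheap.
        self._offsets = []
        pos = 0
        while pos < len(data):
            self._offsets.append(pos)
            nl = data.find(b"\n", pos)
            pos = len(data) if nl == -1 else nl + 1
        self._data = data

    def __len__(self) -> int:
        return len(self._offsets)

    def __getitem__(self, index: int):
        start = self._offsets[index]
        end = self._data.find(b"\n", start)
        end = len(self._data) if end == -1 else end
        return json.loads(self._data[start:end])
```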
- If the dataset is hosted on a Cloud Storage bucket, the `gs://` path can be provided directly. However, for the best performance, it's recommended to read the bucket through Cloud Storage FUSE. This significantly improves performance for the ArrayRecord format, since metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in `setup.sh`. The user then needs to mount the Cloud Storage bucket to a local path on each worker, using the script `setup_gcsfuse.sh`, which configures several parameters for the mount:
```bash
bash tools/setup/setup_gcsfuse.sh \
DATASET_GCS_BUCKET=${BUCKET_NAME?} \
MOUNT_PATH=${MOUNT_PATH?} \
[FILE_PATH=${MOUNT_PATH?}/my_dataset]
```

Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` to pre-fill the metadata cache (see "Performance tuning best practices" in the Google Cloud documentation).
- Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, and `grain_train_files` in `src/maxtext/configs/base.yml` or through command-line arguments to match the file pattern on the mounted local path.
- Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in behind_the_scenes, `grain_pool.py`). If you use a large number of workers, check your gcsfuse config in `setup_gcsfuse.sh` to avoid gcsfuse throttling.
- ArrayRecord only: For multi-source blending, you can specify multiple data sources with their respective weights, using a semicolon (`;`) to separate sources and a comma (`,`) to attach a weight to each source. The weights are automatically normalized to sum to 1.0 (a sketch of this parsing appears after this item). For example:

```bash
# Blend two data sources with 30% from the first source and 70% from the second
grain_train_files=/tmp/gcsfuse/dataset1.array_record*,0.3;/tmp/gcsfuse/dataset2.array_record*,0.7

# Blend three data sources with equal weights (normalized to 0.33 each)
grain_train_files=/tmp/gcsfuse/dataset1.array_record*,1;/tmp/gcsfuse/dataset2.array_record*,1;/tmp/gcsfuse/dataset3.array_record*,1
```

Advanced usage: updating the data mixture when resuming training from a checkpoint. To use this feature, define the data mixture and name the datasets in a JSON file, and set the `grain_train_mixture_config_path` flag to point to this file. When resuming from a checkpoint, you can provide a new JSON file with an updated mixture, allowing dynamic adjustment of data sources and their weights.

For example, you can start a training run with a `grain_mixture.json` file:

```json
{
  "ds1": {
    "path": "gs://path/to/dataset1.array_record*",
    "weight": 0.4
  },
  "ds2": {
    "path": "gs://path/to/dataset2.array_record*",
    "weight": 0.6
  }
}
```

Then, you can resume the training run with a different mixture in `grain_mixture2.json`, which adds a new dataset:

```json
{
  "ds1": {
    "path": "gs://path/to/dataset1.array_record*",
    "weight": 0.5
  },
  "ds2": {
    "path": "gs://path/to/dataset2.array_record*",
    "weight": 0.3
  },
  "ds3": {
    "path": "gs://path/to/dataset3.array_record*",
    "weight": 0.2
  }
}
```

Similarly, you can remove datasets or change weights in the mixture. Grain will correctly handle the state of the data iterators.
Packing and multi-process prefetching (`mp_prefetch`) operations rely on internal buffers. When the data mixture is updated, these buffers cannot be recovered: their unused elements are discarded, resulting in minor skipping in the training data.
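For concreteness, here is a small sketch of how a weighted mixture string of the shape shown above can be parsed and its weights normalized to sum to 1.0. This is illustrative only, not MaxText's actual parsing code:

```python
def parse_mixture(spec: str) -> list[tuple[str, float]]:
    """Parse 'pattern,weight;pattern,weight;...' and normalize the weights."""
    sources = []
    for entry in spec.split(";"):
        pattern, weight = entry.rsplit(",", 1)  # the weight follows the last comma
        sources.append((pattern, float(weight)))
    total = sum(w for _, w in sources)
    return [(pattern, w / total) for pattern, w in sources]

# Three sources with equal raw weights normalize to ~0.33 each.
spec = ("/tmp/gcsfuse/dataset1.array_record*,1;"
        "/tmp/gcsfuse/dataset2.array_record*,1;"
        "/tmp/gcsfuse/dataset3.array_record*,1")
for pattern, weight in parse_mixture(spec):
    print(f"{weight:.2f}  {pattern}")
```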
- Example command:
```bash
# grain_file_type can also be parquet
bash tools/setup/setup_gcsfuse.sh \
DATASET_GCS_BUCKET=maxtext-dataset \
MOUNT_PATH=/tmp/gcsfuse && \
python3 -m maxtext.trainers.pre_train.train \
run_name=<RUN_NAME> base_output_directory=gs://<MY_BUCKET> \
dataset_type=grain \
grain_file_type=arrayrecord \
grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \
grain_worker_count=2
```

- Using validation set for evaluation
When setting `eval_interval > 0`, evaluation will be run with the specified eval dataset. Example config (set in `src/maxtext/configs/base.yml` or through the command line):
```yaml
eval_interval: 10000
eval_steps: 50
grain_eval_files: '/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-validation.array_record*'
```

- Experimental: resuming training with a different chip count
In Grain checkpoints, each data-loading host has a corresponding JSON file. For cases where a user wants to resume training with a different number of data-loading hosts, MaxText provides an experimental feature:
- Scaling up: For example, you have a checkpoint from 64 data-loading hosts and want to resume training with 128. This is achieved by having a subset of the hosts load the real data, which is then sent to the other hosts. The flag `expansion_factor_real_data` (default `-1`) controls this behavior. When set to a value greater than 1, the number of hosts loading real data is `total number of hosts // expansion_factor_real_data`, and each of these data-loading hosts loads `expansion_factor_real_data * per_host_batch_size_to_train` examples (see the worked example after this list). For code integrity, the non-loading hosts use a `PlaceHolderDataIterator` to generate dummy data, which is later discarded. A user can optionally set `max_checkify=true` to enable additional checks ensuring dummy data is not used for training. In this example, you would set `expansion_factor_real_data=2` to scale from 64 to 128 hosts.
- Scaling down: For example, you have a checkpoint from 128 data-loading hosts and want to resume with 64. This is achieved by restoring multiple data iterators on each host. Set the flag `expansion_factor_real_data` so that each host restores `1 / expansion_factor_real_data` data iterators; batches are then produced by alternating between these iterators. In this example, you would set `expansion_factor_real_data=0.5` to scale from 128 down to 64 hosts.
- Note: In both the scaling-up and scaling-down scenarios, `per_device_batch_size` must remain consistent. This is because Grain records the number of iterations (batches) in the iterator's state, and changing the batch size would result in either skipping or duplicating data.
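To make the bookkeeping concrete, here is an illustrative sketch of the quantities involved in both directions. The function and variable names are hypothetical, chosen to mirror the flags above rather than MaxText's actual code:

```python
def data_loading_plan(total_hosts: int, expansion_factor: float, per_host_batch: int):
    """Illustrative host/batch bookkeeping for expansion_factor_real_data."""
    if expansion_factor > 1:
        # Scaling up: a subset of hosts loads real data on behalf of all hosts.
        loading_hosts = total_hosts // int(expansion_factor)
        per_loader_batch = int(expansion_factor) * per_host_batch
        iterators_per_host = 1
    elif 0 < expansion_factor < 1:
        # Scaling down: every host restores several checkpointed iterators
        # and alternates between them to produce batches.
        loading_hosts = total_hosts
        per_loader_batch = per_host_batch
        iterators_per_host = round(1 / expansion_factor)
    else:
        # Default (-1): every host loads its own data with a single iterator.
        loading_hosts, per_loader_batch, iterators_per_host = total_hosts, per_host_batch, 1
    return loading_hosts, per_loader_batch, iterators_per_host

# 64 -> 128 hosts with expansion_factor_real_data=2: 64 hosts load double-size batches.
print(data_loading_plan(total_hosts=128, expansion_factor=2, per_host_batch=8))   # (64, 16, 1)
# 128 -> 64 hosts with expansion_factor_real_data=0.5: each host restores 2 iterators.
print(data_loading_plan(total_hosts=64, expansion_factor=0.5, per_host_batch=8))  # (64, 8, 2)
```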