diff --git a/docs/guides/data_input_pipeline.md b/docs/guides/data_input_pipeline.md index 53b7c2924b..8772db513e 100644 --- a/docs/guides/data_input_pipeline.md +++ b/docs/guides/data_input_pipeline.md @@ -15,29 +15,34 @@ --> (data-input-pipeline)= + # Data pipelines Currently MaxText has three data input pipelines: -| Pipeline | Dataset formats | Features | Limitations | -| -------- | --------------- | -------- | ----------- | -| **[Grain](data_input_pipeline/data_input_grain.md)** (recommended)| [ArrayRecord](https://github.com/google/array_record) (random access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview), or [conversion](https://github.com/google/array_record/tree/main/beam))
[Parquet](https://arrow.apache.org/docs/python/parquet.html) (sequential access) | With arrayrecord: fully deterministic, resilient to preemption; global shuffle
With parquet: performant; fully deterministic, resilient to preemption; hierarchical shuffle | | -| **[Hugging Face](data_input_pipeline/data_input_hf.md)** | datasets in [Hugging Face Hub](https://huggingface.co/datasets)
local/Cloud Storage datasets in json, parquet, arrow, csv, txt (sequential access) | no download needed, convenience;
multiple formats | limit scalability using the Hugging Face Hub (no limit using Cloud Storage);
non-deterministic with preemption
(deterministic without preemption)
| -| **[TFDS](data_input_pipeline/data_input_tfds.md)** | TFRecord (sequential access), available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview) | performant | only supports TFRecords;
non-deterministic with preemption
(deterministic without preemption) | +| Pipeline | Dataset formats | Features | Limitations | +| ------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| **[Grain](data_input_pipeline/data_input_grain.md)** (recommended) | [ArrayRecord](https://github.com/google/array_record) (random access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview), or [conversion](https://github.com/google/array_record/tree/main/beam))
[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview))
[Parquet](https://arrow.apache.org/docs/python/parquet.html) (sequential access) | With arrayrecord: fully deterministic, resilient to preemption; global shuffle
With parquet: performant; fully deterministic, resilient to preemption; hierarchical shuffle | | +| **[Hugging Face](data_input_pipeline/data_input_hf.md)** | datasets in [Hugging Face Hub](https://huggingface.co/datasets)
local/Cloud Storage datasets in json, parquet, arrow, csv, txt (sequential access) | no download needed, convenient;
supports multiple formats | limited scalability when using the Hugging Face Hub (no limit when using Cloud Storage);
non-deterministic with preemption
(deterministic without preemption)
| +| **[TFDS](data_input_pipeline/data_input_tfds.md)** | TFRecord (sequential access), available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview) | performant | only supports TFRecords;
non-deterministic with preemption
(deterministic without preemption) | ## Multihost dataloading best practice + Training in a multi-host environment presents unique challenges for data input pipelines. An effective data loading strategy must address three key issues: + 1. **Concurrent access**: Multiple hosts need to read from the same dataset simultaneously without causing conflicts. 2. **Data uniqueness**: Each host must be fed a unique, non-overlapping subset of the data to ensure the model sees each example correctly. -3. **Uneven completion**: Handling the scenario where some hosts run out of data before others, which can lead to hanging. -The approaches to solve these challenges depend on whether your dataset supports random access or is limited to sequential access. +3. **Uneven completion**: Handling the scenario where some hosts run out of data before others, which can lead to hanging. + The approaches to solve these challenges depend on whether your dataset supports random access or is limited to sequential access. ### Random access dataset (Recommended) + Random-access formats are highly recommended for multi-host training because they allow any part of the file to be read directly by its index.
In MaxText, this is best supported by the ArrayRecord format using the Grain input pipeline. This approach gracefully handles the key challenges: -* **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file. -* **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met. +- **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file. + +- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met. ```{note} When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_batch_train`/`generate_padding_batch_eval` flag still solves this. 
@@ -48,12 +53,14 @@ If all hosts exhaust their data before the target step count is reached, both `t ``` ### Sequential access dataset -* **Concurrent access and uniqueness**: Sequential-access datasets (e.g., Parquet, JSON, TFRecord) cannot be accessed by index, requiring a different strategy -- file-based sharding, where each host is given exclusive access to a specific subset of data files. **Key requirement**: `(Number of data files) % (Number of data-loading hosts) == 0`. If the file count isn't a multiple of the host count, the files will be distributed unevenly. For example, with 10 files and 8 hosts, some hosts will get two files while others get one, significantly worsening the "uneven completion" problem. If you have fewer files than hosts, performance will be severely degraded as all hosts are concurrently accessing all the files. -* **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_batch_train`/`generate_padding_batch_eval` flag to handle hosts that finish their file shards early. -```{toctree} -:hidden: +- **Concurrent access and uniqueness**: Sequential-access datasets (e.g., Parquet, JSON, TFRecord) cannot be accessed by index, requiring a different strategy -- file-based sharding, where each host is given exclusive access to a specific subset of data files. **Key requirement**: `(Number of data files) % (Number of data-loading hosts) == 0`. If the file count isn't a multiple of the host count, the files will be distributed unevenly. For example, with 10 files and 8 hosts, some hosts will get two files while others get one, significantly worsening the "uneven completion" problem. If you have fewer files than hosts, performance will be severely degraded as all hosts are concurrently accessing all the files. +- **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_batch_train`/`generate_padding_batch_eval` flag to handle hosts that finish their file shards early. 
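The file-sharding arithmetic behind the key requirement above can be sketched in a few lines of Python (`shard_files` is a hypothetical helper for illustration, not MaxText's actual implementation):

```python
def shard_files(files: list[str], host_index: int, host_count: int) -> list[str]:
    # Each host takes every host_count-th file starting at its own index,
    # so shards are disjoint and together cover every file exactly once.
    return files[host_index::host_count]

# 10 files across 8 hosts: the file count is not a multiple of the host count,
# so hosts 0 and 1 receive two files while the others receive one.
files = [f"part-{i:05d}.parquet" for i in range(10)]
shard_sizes = [len(shard_files(files, h, 8)) for h in range(8)]
print(shard_sizes)  # [2, 2, 1, 1, 1, 1, 1, 1]
```

Hosts 0 and 1 would keep producing batches after the other six run out, which is exactly the uneven-completion scenario the padding-batch flags are designed to absorb.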
+```{toctree} +--- +hidden: +--- data_input_pipeline/data_input_grain data_input_pipeline/data_input_hf data_input_pipeline/data_input_tfds diff --git a/docs/guides/data_input_pipeline/data_input_grain.md b/docs/guides/data_input_pipeline/data_input_grain.md index a125cb2a13..1191d2ff7b 100644 --- a/docs/guides/data_input_pipeline/data_input_grain.md +++ b/docs/guides/data_input_pipeline/data_input_grain.md @@ -32,9 +32,9 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state ## Using Grain -1. Grain currently supports two data formats: [ArrayRecord](https://github.com/google/array_record) (random access) and [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random-access through row groups). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class. +1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups), and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class. - **Community Resource**: The MaxText community has created [ArrayRecord documentation](https://array-record.readthedocs.io/). 
Note: we appreciate the contribution from the community, but it has not yet been verified by the MaxText or ArrayRecord developers. -2. When the dataset is hosted on a Cloud Storage bucket, Grain can read it through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount. +2. If the dataset is hosted on a Cloud Storage bucket, a `gs://` path can be provided directly. However, for the best performance, it's recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This significantly improves performance for the ArrayRecord format, since metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount. ```sh bash tools/setup/setup_gcsfuse.sh \ @@ -45,7 +45,7 @@ MOUNT_PATH=${MOUNT_PATH?} \ Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://cloud.google.com/storage/docs/cloud-storage-fuse/performance#improve-first-time-reads)). -1. 
Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path. +1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path. 2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling. diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 398df849fe..48ca23757f 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -679,6 +679,7 @@ grain_ram_budget_mb: 1024 # RAM budget (MB) for auto-tuning worker count. Only u grain_num_threads_eval: 16 grain_prefetch_buffer_size_eval: 500 grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources. +grain_shuffle_buffer_size: 100 # Shuffle buffer size for sequential-access formats such as Parquet or TFRecord. 
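Since sequential formats cannot be globally shuffled, `grain_shuffle_buffer_size` bounds a windowed shuffle over the example stream (Grain's `WindowShuffleIterDataset`). A minimal stand-alone sketch of the idea, using only the standard library and a hypothetical `window_shuffle` generator (not Grain's implementation):

```python
import random

def window_shuffle(stream, window_size: int, seed: int):
    # Keep up to window_size elements buffered; once the buffer is full,
    # emit a randomly chosen buffered element for each new one that arrives.
    # Larger windows approximate a global shuffle more closely.
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) >= window_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    rng.shuffle(buffer)  # drain the remaining tail in random order
    yield from buffer

shuffled = list(window_shuffle(range(1000), window_size=100, seed=0))
```

For a fixed seed the sketch is fully deterministic, and each element can appear at most `window_size - 1` positions earlier than it would in the unshuffled stream.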
# for using pathways colocated_python_data_input: False # experimental feature, under testing diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 3df51ac106..680fa93dda 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1030,17 +1030,13 @@ class GrainDataset(BaseModel): "", description="Path to a JSON file specifying the mixture weights for Grain training data.", ) - grain_file_type: str = Field("arrayrecord", description="File type for Grain data.") - grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.") - grain_per_worker_buffer_size: int = Field( - 1, - description="Buffer size for each worker for Grain data loading during training.", + grain_file_type: str = Field( + "arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet." ) + grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.") + grain_per_worker_buffer_size: int = Field(1, description="Per-worker buffer size for Grain train data loading.") grain_worker_count_eval: int = Field(1, description="Number of workers for Grain eval data loading.") - grain_per_worker_buffer_size_eval: int = Field( - 1, - description="Buffer size for each worker for Grain data loading during evaluation.", - ) + grain_per_worker_buffer_size_eval: int = Field(1, description="Per-worker buffer size for Grain eval data loading.") grain_ram_budget_mb: int = Field(1024, description="RAM budget (MB) for auto-tuning worker count.") grain_num_threads: int = Field(16, description="Number of threads for Grain ReadOptions during training.") grain_prefetch_buffer_size: int = Field(500, description="Prefetch buffer size for Grain ReadOptions during training.") @@ -1052,6 +1048,7 @@ class GrainDataset(BaseModel): 16, description="Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.", ) + grain_shuffle_buffer_size: int = Field(100, description="Shuffle 
buffer size when using Parquet or TFRecord.") class FineTuning(BaseModel): diff --git a/src/maxtext/input_pipeline/grain_data_processing.py b/src/maxtext/input_pipeline/grain_data_processing.py index 154dac457a..4488a71753 100644 --- a/src/maxtext/input_pipeline/grain_data_processing.py +++ b/src/maxtext/input_pipeline/grain_data_processing.py @@ -76,6 +76,7 @@ def get_datasets( data_file_type, shuffle, shuffle_seed, + shuffle_buffer_size, num_epoch, dataloading_host_index, dataloading_host_count, @@ -167,6 +168,20 @@ def create_dataset_from_pattern(pattern): grain_prefetch_buffer_size, ) return dataset + elif data_file_type == "tfrecord": + data_files = find_data_files(data_file_pattern) + dataset = grain.MapDataset.source(data_files) + if shuffle: + dataset = dataset.shuffle(seed=shuffle_seed) + dataset = dataset.repeat(num_epoch) + dataset = dataset[dataloading_host_index::dataloading_host_count] # sharding + dataset = dataset.map(input_pipeline_utils.make_tfrecord_iter_dataset) + files_per_host = max(len(data_files) // dataloading_host_count, 1) + cycle_length = min(files_per_host, grain_num_threads) + dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=cycle_length) + if shuffle: + dataset = grain.experimental.WindowShuffleIterDataset(dataset, window_size=shuffle_buffer_size, seed=shuffle_seed) + return dataset elif data_file_type == "parquet": data_files = find_data_files(data_file_pattern) dataset = grain.MapDataset.source(data_files) @@ -183,10 +198,12 @@ def create_dataset_from_pattern(pattern): cycle_length = min(len(dataset) // num_epoch, grain_num_threads) dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=cycle_length) if shuffle: - dataset = grain.experimental.WindowShuffleIterDataset(dataset, window_size=100, seed=shuffle_seed) + dataset = grain.experimental.WindowShuffleIterDataset(dataset, window_size=shuffle_buffer_size, seed=shuffle_seed) return dataset else: - raise ValueError(f"grain pipeline supports 
(arrayrecord, parquet) as grain_file_type, but got {data_file_type}") + raise ValueError( + f"grain pipeline supports (arrayrecord, tfrecord, parquet) as grain_file_type, but got {data_file_type}" + ) def pretrain_preprocessing_pipeline( @@ -198,7 +215,7 @@ def pretrain_preprocessing_pipeline( grain_per_worker_buffer_size, ): """Use grain pipeline to pre-process the dataset and return iterators for pretrain""" - if config.grain_file_type == "arrayrecord": + if config.grain_file_type in ("arrayrecord", "tfrecord"): dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize)) dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) else: @@ -311,7 +328,7 @@ def dpo_preprocessing_pipeline( grain_per_worker_buffer_size, ): """Use grain to pre-process the dataset and return iterators for dpo fine-tuning""" - if config.grain_file_type == "arrayrecord": + if config.grain_file_type in ("arrayrecord", "tfrecord"): dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize)) dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) tokenizer_model = tokenizer.build_tokenizer( @@ -367,6 +384,7 @@ def make_grain_train_iterator( config.grain_file_type, shuffle=config.enable_data_shuffling, shuffle_seed=config.data_shuffle_seed, + shuffle_buffer_size=config.grain_shuffle_buffer_size, num_epoch=config.num_epoch, dataloading_host_index=process_indices.index(jax.process_index()), dataloading_host_count=len(process_indices), @@ -407,6 +425,7 @@ def make_grain_train_iterator( config.grain_file_type, shuffle=config.enable_data_shuffling, shuffle_seed=config.data_shuffle_seed, + shuffle_buffer_size=config.grain_shuffle_buffer_size, num_epoch=config.num_epoch, grain_worker_count=config.grain_worker_count, grain_num_threads=config.grain_num_threads, @@ -465,6 +484,7 @@ def make_grain_eval_iterator( config.grain_file_type, shuffle=False, shuffle_seed=config.data_shuffle_seed, + 
shuffle_buffer_size=config.grain_shuffle_buffer_size, num_epoch=1, dataloading_host_index=process_indices.index(jax.process_index()), dataloading_host_count=len(process_indices), @@ -501,6 +521,7 @@ def make_grain_eval_iterator( config.grain_file_type, shuffle=False, # No shuffle for eval shuffle_seed=config.data_shuffle_seed, + shuffle_buffer_size=config.grain_shuffle_buffer_size, num_epoch=1, grain_worker_count=config.grain_worker_count_eval, grain_num_threads=config.grain_num_threads_eval, diff --git a/src/maxtext/input_pipeline/input_pipeline_utils.py b/src/maxtext/input_pipeline/input_pipeline_utils.py index e6fb9d9222..0048ff4193 100644 --- a/src/maxtext/input_pipeline/input_pipeline_utils.py +++ b/src/maxtext/input_pipeline/input_pipeline_utils.py @@ -25,10 +25,13 @@ import grain.python as grain import numpy as np +from grain._src.python.dataset.sources.tfrecord_dataset import _TFRecordReader, _TFRecordDatasetIterator # pylint: disable=protected-access +from grain.experimental import TFRecordIterDataset from maxtext.input_pipeline.protos import example_pb2 from maxtext.input_pipeline import tokenizer from maxtext.multimodal import processor as mm_processor from maxtext.multimodal import utils as mm_utils +from maxtext.utils import gcs_utils from maxtext.utils import max_logging Features = dict[str, Any] @@ -418,6 +421,38 @@ def __getitem__(self, index): ########## Functions used by Grain pipeline +class _GCSTFRecordReader(_TFRecordReader): + """Extends Grain's _TFRecordReader to open TFRecord files from GCS via streaming BlobReader.""" + + def __init__(self, path: str): + # Skip parent __init__ (which calls open(path, "rb")) and open via GCS BlobReader instead. 
+ bucket_name, blob_name = gcs_utils.parse_gcs_bucket_and_prefix(path) + self._reader = gcs_utils.storage.Client().bucket(bucket_name).blob(blob_name).open("rb") + + +class _GCSTFRecordDatasetIterator(_TFRecordDatasetIterator): + """Extends Grain's _TFRecordDatasetIterator to use _GCSTFRecordReader for GCS paths.""" + + def __init__(self, path: str): + # Skip parent __init__ (which creates _TFRecordReader); use GCS-aware reader instead. + grain.DatasetIterator.__init__(self) + self._reader = _GCSTFRecordReader(path) + + +class GCSTFRecordIterDataset(TFRecordIterDataset): + """Extends Grain's TFRecordIterDataset to support GCS paths.""" + + def __iter__(self) -> grain.DatasetIterator: # pylint: disable=non-iterator-returned + return _GCSTFRecordDatasetIterator(self._path) + + +def make_tfrecord_iter_dataset(path: str): + """Returns the appropriate TFRecordIterDataset for local or GCS paths.""" + if path.startswith("gs://"): + return GCSTFRecordIterDataset(path) + return TFRecordIterDataset(path) + + @dataclasses.dataclass class ParseFeatures(grain.MapTransform): """Parse serialized example""" diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index 90700b3790..f293f14e7c 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -14,8 +14,6 @@ """Tests for grain data processing.""" -import glob -import subprocess import sys import os.path import tempfile @@ -30,70 +28,18 @@ from maxtext.configs import pyconfig from maxtext.input_pipeline import grain_data_processing from maxtext.input_pipeline import input_pipeline_interface -from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT +from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT from maxtext.common.gcloud_stub import is_decoupled from tests.utils.test_helpers import get_test_base_output_directory, get_test_config_path, get_test_dataset_path -class GrainArrayRecordProcessingTest(unittest.TestCase): +class 
GrainBaseProcessingTest: + """Base mixin with shared test methods for all grain data processing tests. - @classmethod - def setUpClass(cls): - super().setUpClass() - mount_gcsfuse() - - def setUp(self): - super().setUp() - temp_dir = tempfile.gettempdir() - decoupled = is_decoupled() - - if decoupled: - dataset_root = get_test_dataset_path() - grain_train_files = os.path.join( - dataset_root, - "c4", - "en", - "3.0.1", - "c4-train.array_record-*", - ) - base_output_directory = get_test_base_output_directory() - else: - grain_train_files = os.path.join( - temp_dir, - "gcsfuse", - "array-record", - "c4", - "en", - "3.0.1", - "c4-train.array_record*", - ) - base_output_directory = "gs://max-experiments/" - - config_file = get_test_config_path() - - self.config = pyconfig.initialize( - [sys.argv[0], config_file], - per_device_batch_size=1, - run_name="test", - mesh_axes=["data"], - logical_axis_rules=[["batch", "data"]], - data_sharding=["data"], - base_output_directory=base_output_directory, - dataset_type="grain", - grain_train_files=grain_train_files, - tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), - enable_checkpointing=False, - ) - self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) - self.process_indices = input_pipeline_interface.get_process_loading_real_data( - self.config.data_sharding, - self.config.global_batch_size_to_load, - self.config.global_batch_size_to_train_on, - self.config.max_target_length, - self.mesh, - ) - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + Does not inherit from unittest.TestCase to prevent the test runner from + discovering and executing it directly. Concrete subclasses must also inherit + from unittest.TestCase (or a subclass thereof). 
+ """ def test_train_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] @@ -112,7 +58,6 @@ def test_train_ds(self): }, ) - @pytest.mark.external_serving # Skipped in decoupled mode due to rocBLAS scratch buffer TF issues on GPU def test_batch_determinism(self): batch1 = next(self.train_iter) train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) @@ -133,45 +78,49 @@ def get_first_batch(iterator): train_batch1 = get_first_batch(self.train_iter) train_batch2 = get_first_batch(self.train_iter) - self.assertTrue((train_batch1["inputs"] == train_batch2["inputs"]).all()) - self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all()) + self.assertTrue((train_batch1["inputs"] == train_batch2["inputs"]).all()) # pytype: disable=unsupported-operands + self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all()) # pytype: disable=unsupported-operands -class GrainArrayRecordProcessingWithMultiSourceBlendingTest(GrainArrayRecordProcessingTest): +class GrainArrayRecordProcessingTest(GrainBaseProcessingTest, unittest.TestCase): + """Test grain data processing with ArrayRecord format. + + In decoupled mode, reads directly from GCS. 
Otherwise, reads from GCSFUSE mounted path + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() def setUp(self): - # Override parent setUp to use multi-source blending + super().setUp() temp_dir = tempfile.gettempdir() decoupled = is_decoupled() if decoupled: dataset_root = get_test_dataset_path() - base_pattern = os.path.join( + grain_train_files = os.path.join( dataset_root, "c4", + "array-record", "en", "3.0.1", - "c4-train.array_record-*", + "c4-train.array_record-00000-of-01024", ) base_output_directory = get_test_base_output_directory() - config_file = get_test_config_path() else: - base_pattern = os.path.join( + grain_train_files = os.path.join( temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", - "c4-train.array_record*", + "c4-train.array_record-00000-of-01024", ) base_output_directory = "gs://max-experiments/" - config_file = get_test_config_path() - # Ensure GCS fuse mounted for cloud path usage - mount_gcsfuse() - - train_files_weighted = ";".join([f"{base_pattern},0.3", f"{base_pattern},0.7"]) + config_file = get_test_config_path() self.config = pyconfig.initialize( [sys.argv[0], config_file], per_device_batch_size=1, @@ -181,7 +130,7 @@ def setUp(self): data_sharding=["data"], base_output_directory=base_output_directory, dataset_type="grain", - grain_train_files=train_files_weighted, + grain_train_files=grain_train_files, tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), enable_checkpointing=False, ) @@ -196,6 +145,36 @@ def setUp(self): ) self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + def _make_config(self, **overrides): + """Re-initialize config with base params, applying any overrides.""" + kwargs = { + "per_device_batch_size": 1, + "run_name": "test", + "mesh_axes": ["data"], + "logical_axis_rules": [["batch", "data"]], + "data_sharding": ["data"], + "base_output_directory": self.config.base_output_directory, + 
"dataset_type": "grain", + "grain_train_files": self.config.grain_train_files, + "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), + "enable_checkpointing": False, + **overrides, + } + return pyconfig.initialize([sys.argv[0], get_test_config_path()], **kwargs) + + @pytest.mark.external_serving # Skipped in decoupled mode due to rocBLAS scratch buffer TF issues on GPU + def test_batch_determinism(self): + super().test_batch_determinism() + + +class GrainArrayRecordProcessingWithMultiSourceBlendingTest(GrainArrayRecordProcessingTest): + + def setUp(self): + super().setUp() + train_files_weighted = ";".join([f"{self.config.grain_train_files},0.3", f"{self.config.grain_train_files},0.7"]) + self.config = self._make_config(grain_train_files=train_files_weighted) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + class GrainArrayRecordProcessingWithMixtureConfigTest(GrainArrayRecordProcessingTest): @@ -228,7 +207,6 @@ def setUp(self): "weight": 0.7, }, } - base_output_directory = get_test_base_output_directory() else: mixture_config = { "ds1": { @@ -240,33 +218,11 @@ def setUp(self): "weight": 0.7, }, } - base_output_directory = "gs://max-experiments/" self.mixture_config_path = os.path.join(temp_dir, "mixture_config.json") with open(self.mixture_config_path, "w", encoding="utf-8") as f: json.dump(mixture_config, f) - self.config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - per_device_batch_size=1, - run_name="test", - mesh_axes=["data"], - logical_axis_rules=[["batch", "data"]], - data_sharding=["data"], - base_output_directory=base_output_directory, - dataset_type="grain", - grain_train_mixture_config_path=self.mixture_config_path, - tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), - enable_checkpointing=False, - ) - self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = 
Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) - self.process_indices = input_pipeline_interface.get_process_loading_real_data( - self.config.data_sharding, - self.config.global_batch_size_to_load, - self.config.global_batch_size_to_train_on, - self.config.max_target_length, - self.mesh, - ) + self.config = self._make_config(grain_train_mixture_config_path=self.mixture_config_path) self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) @@ -276,55 +232,8 @@ class GrainArrayRecordAutoTuneTest(GrainArrayRecordProcessingTest): """Test grain data processing with auto-tuning enabled (grain_worker_count=-1).""" def setUp(self): - temp_dir = tempfile.gettempdir() - decoupled = is_decoupled() - - if decoupled: - dataset_root = get_test_dataset_path() - grain_train_files = os.path.join( - dataset_root, - "c4", - "en", - "3.0.1", - "c4-train.array_record-*", - ) - base_output_directory = get_test_base_output_directory() - else: - grain_train_files = os.path.join( - temp_dir, - "gcsfuse", - "array-record", - "c4", - "en", - "3.0.1", - "c4-train.array_record*", - ) - base_output_directory = "gs://max-experiments/" - - self.config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - per_device_batch_size=1, - run_name="test", - mesh_axes=["data"], - logical_axis_rules=[["batch", "data"]], - data_sharding=["data"], - base_output_directory=base_output_directory, - dataset_type="grain", - grain_ram_budget_mb=512, - grain_train_files=grain_train_files, - grain_worker_count=-1, # Enable auto-tuning - tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), - enable_checkpointing=False, - ) - self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) - self.process_indices = input_pipeline_interface.get_process_loading_real_data( - self.config.data_sharding, - 
self.config.global_batch_size_to_load, - self.config.global_batch_size_to_train_on, - self.config.max_target_length, - self.mesh, - ) + super().setUp() + self.config = self._make_config(grain_ram_budget_mb=512, grain_worker_count=-1) # Enable auto-tuning self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) @pytest.mark.skip( @@ -345,37 +254,48 @@ class GrainArrayRecordBestFitPackingTest(GrainArrayRecordProcessingTest): """Test grain data processing with best_fit packing strategy.""" def setUp(self): + super().setUp() + self.config = self._make_config(grain_packing_type="best_fit") + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + + +class GrainParquetProcessingTest(GrainBaseProcessingTest, unittest.TestCase): + """Test grain data processing with Parquet format. + + In decoupled mode, reads directly from GCS. Otherwise, reads from GCSFUSE mounted path + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def setUp(self): + super().setUp() temp_dir = tempfile.gettempdir() decoupled = is_decoupled() if decoupled: dataset_root = get_test_dataset_path() - grain_train_files = os.path.join( + grain_train_file = os.path.join( dataset_root, + "hf", "c4", - "en", - "3.0.1", - "c4-train.array_record-*", + "c4-train-00000-of-01637.parquet", ) base_output_directory = get_test_base_output_directory() else: - mount_gcsfuse() - grain_train_files = os.path.join( + grain_train_file = os.path.join( temp_dir, "gcsfuse", - "array-record", + "hf", "c4", - "en", - "3.0.1", - "c4-train.array_record*", + "c4-train-00000-of-01637.parquet", ) - # If the external dataset isn't available, skip rather than failing. 
-      if not glob.glob(grain_train_files):
-        pytest.skip(f"No files found matching pattern: {grain_train_files}")
       base_output_directory = "gs://max-experiments/"
+    config_file = get_test_config_path()

     self.config = pyconfig.initialize(
-        [sys.argv[0], get_test_config_path()],
+        [sys.argv[0], config_file],
         per_device_batch_size=1,
         run_name="test",
         mesh_axes=["data"],
@@ -383,8 +303,10 @@ def setUp(self):
         data_sharding=["data"],
         base_output_directory=base_output_directory,
         dataset_type="grain",
-        grain_train_files=grain_train_files,
-        grain_packing_type="best_fit",  # Use best_fit packing
+        grain_file_type="parquet",
+        grain_train_files=grain_train_file,
+        grain_worker_count=1,
+        grain_per_worker_buffer_size=1,
         tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
         enable_checkpointing=False,
     )
@@ -400,12 +322,15 @@ def setUp(self):
     self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)


-class GrainParquetProcessingTest(unittest.TestCase):
+class GrainTFRecordProcessingTest(GrainBaseProcessingTest, unittest.TestCase):
+  """Test grain data processing with TFRecord format.
+
+  In decoupled mode, reads directly from GCS. Otherwise, reads from the GCSFUSE-mounted path.
+  """

   @classmethod
   def setUpClass(cls):
     super().setUpClass()
-    mount_gcsfuse()

   def setUp(self):
     super().setUp()
@@ -416,23 +341,24 @@ def setUp(self):
       dataset_root = get_test_dataset_path()
       grain_train_file = os.path.join(
           dataset_root,
-          "hf",
           "c4",
-          "c4-train-00000-of-01637.parquet",
+          "en",
+          "3.0.1",
+          "c4-train.tfrecord-00000-of-01024",
       )
       base_output_directory = get_test_base_output_directory()
-      config_file = get_test_config_path()
     else:
       grain_train_file = os.path.join(
           temp_dir,
           "gcsfuse",
-          "hf",
           "c4",
-          "c4-train-00000-of-01637.parquet",
+          "en",
+          "3.0.1",
+          "c4-train.tfrecord-00000-of-01024",
       )
       base_output_directory = "gs://max-experiments/"
-      config_file = get_test_config_path()
+    config_file = get_test_config_path()
     self.config = pyconfig.initialize(
         [sys.argv[0], config_file],
         per_device_batch_size=1,
@@ -442,7 +368,7 @@ def setUp(self):
         data_sharding=["data"],
         base_output_directory=base_output_directory,
         dataset_type="grain",
-        grain_file_type="parquet",
+        grain_file_type="tfrecord",
         grain_train_files=grain_train_file,
         grain_worker_count=1,
         grain_per_worker_buffer_size=1,
@@ -460,71 +386,6 @@ def setUp(self):
     )
     self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)

-  def test_train_ds(self):
-    expected_shape = [jax.device_count(), self.config.max_target_length]
-    # For training we pack multiple short examples in one example.
-    # *_position and *_segmentation indicate the boundaries.
-    batch = next(self.train_iter)
-    self.assertEqual(
-        {k: list(v.shape) for k, v in batch.items()},
-        {
-            "inputs": expected_shape,
-            "inputs_position": expected_shape,
-            "inputs_segmentation": expected_shape,
-            "targets": expected_shape,
-            "targets_position": expected_shape,
-            "targets_segmentation": expected_shape,
-        },
-    )
-
-  def test_batch_determinism(self):
-    batch1 = next(self.train_iter)
-    train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
-    batch2 = next(train_iter)
-    self.assertTrue((batch1["inputs"] == batch2["inputs"]).all())
-    self.assertTrue((batch1["targets"] == batch2["targets"]).all())
-    self.assertTrue((batch1["inputs_segmentation"] == batch2["inputs_segmentation"]).all())
-    self.assertTrue((batch1["targets_segmentation"] == batch2["targets_segmentation"]).all())
-    self.assertTrue((batch1["inputs_position"] == batch2["inputs_position"]).all())
-    self.assertTrue((batch1["targets_position"] == batch2["targets_position"]).all())
-
-  def test_for_loop_repeatable(self):
-    def get_first_batch(iterator):
-      batch = None
-      for batch in iterator:
-        break
-      return batch
-
-    train_batch1 = get_first_batch(self.train_iter)
-    train_batch2 = get_first_batch(self.train_iter)
-    self.assertTrue((train_batch1["inputs"] == train_batch2["inputs"]).all())  # pytype: disable=unsupported-operands
-    self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all())  # pytype: disable=unsupported-operands
-
-
-def mount_gcsfuse():
-  """
-  Mounts a GCS bucket (gs://maxtext-dataset) to a local directory (/tmp/gcsfuse)
-  using gcsfuse if not already mounted.
-  """
-
-  if is_decoupled():
-    return  # No-op when decoupled.
-  temp_dir = tempfile.gettempdir()
-  mount_path = os.path.join(temp_dir, "gcsfuse")
-
-  # Only mount if the directory is empty or not present
-  if not os.path.isdir(mount_path) or not os.listdir(mount_path):
-    script_path = os.path.join(MAXTEXT_REPO_ROOT, "setup_gcsfuse.sh")
-    if not os.path.isfile(script_path):
-      raise FileNotFoundError(script_path)
-
-    exit_code = subprocess.call(
-        ["bash", script_path, "DATASET_GCS_BUCKET=maxtext-dataset", f"MOUNT_PATH={os.path.join(temp_dir, 'gcsfuse')}"]
-    )
-    if exit_code != os.EX_OK:
-      raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}")


 if __name__ == "__main__":
-  mount_gcsfuse()
   unittest.main()