diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 17bfac932a..8efd48a065 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -182,13 +182,7 @@ def _default_for_sds(sds): def _make(): if "key" in str(sds.dtype): base = jax.random.key(0) - return ( - base - if sds.shape == () - else jax.random.split(base, int(np.prod(sds.shape))).reshape( - sds.shape - ) - ) + return base if sds.shape == () else jax.random.split(base, int(np.prod(sds.shape))).reshape(sds.shape) return jnp.zeros(sds.shape, dtype=sds.dtype) sharding = getattr(sds, "sharding", None) @@ -208,9 +202,7 @@ def _populate_pure_dict_from_partial(abstract_pure, partial_concrete): return { k: _populate_pure_dict_from_partial( v, - partial_concrete.get(k) - if isinstance(partial_concrete, dict) - else None, + partial_concrete.get(k) if isinstance(partial_concrete, dict) else None, ) for k, v in abstract_pure.items() } @@ -243,9 +235,7 @@ def _load_linen_checkpoint_into_nnx( ) ) restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract) - restored = ocp.args.PyTreeRestore( - item=linen_abstract, restore_args=restore_args, partial_restore=True - ) + restored = ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True) restored = ckptr.restore(epath.Path(path), args=restored) partial_nnx = train_state_nnx.from_linen_checkpoint_dict(restored) return _populate_pure_dict_from_partial(nnx_abstract_pure, partial_nnx) @@ -253,19 +243,13 @@ def _load_linen_checkpoint_into_nnx( def _rebuild_nnx_with_values(abstract_nnx_state, concrete_weights): """Fills each Variable in `abstract_nnx_state` with the matching restored array.""" - leaves, treedef = jax.tree_util.tree_flatten( - abstract_nnx_state, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) + leaves, treedef = jax.tree_util.tree_flatten(abstract_nnx_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) concrete = jax.tree_util.tree_leaves(concrete_weights) if len(leaves) != len(concrete): raise ValueError( - f"Params load leaf-count mismatch: {len(leaves)} abstract Variables vs" - f" {len(concrete)} restored." + f"Params load leaf-count mismatch: {len(leaves)} abstract Variables vs" f" {len(concrete)} restored." ) - new_leaves = [ - v.replace(value=a) if isinstance(v, nnx.Variable) else a - for v, a in zip(leaves, concrete) - ] + new_leaves = [v.replace(value=a) if isinstance(v, nnx.Variable) else a for v, a in zip(leaves, concrete)] return jax.tree_util.tree_unflatten(treedef, new_leaves) @@ -284,9 +268,7 @@ def _load_linen_params_into_nnx( NNX params Variables. """ max_logging.log(f"Restoring Linen-layout params into NNX state at {path}") - linen_abstract = train_state_nnx.to_linen_checkpoint_dict( - {"model": nnx_params_abstract.to_pure_dict()} - ) + linen_abstract = train_state_nnx.to_linen_checkpoint_dict({"model": nnx_params_abstract.to_pure_dict()}) ckptr = ocp.Checkpointer( ocp.PyTreeCheckpointHandler( restore_concurrent_gb=checkpoint_storage_concurrent_gb, @@ -298,13 +280,9 @@ def _load_linen_params_into_nnx( restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract) restored = ckptr.restore( epath.Path(path), - args=ocp.args.PyTreeRestore( - item=linen_abstract, restore_args=restore_args, partial_restore=True - ), - ) - return _rebuild_nnx_with_values( - nnx_params_abstract, restored["params"]["params"] + args=ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True), ) + return _rebuild_nnx_with_values(nnx_params_abstract, restored["params"]["params"]) def _load_full_state_from_path( @@ -388,7 +366,7 @@ def create_orbax_checkpoint_manager( enable_checkpointing: bool, use_async: bool, save_interval_steps: int, - dataset_type: None | str = "tfds", + dataset_type: None | str = None, orbax_logger: Any = None, # pytype: disable=attribute-error use_ocdbt: bool = True, use_zarr3: bool = True, @@ -421,7 +399,7 @@ def create_orbax_checkpoint_manager( ) } - if dataset_type == "grain": + if dataset_type is not None and dataset_type == "grain": item_names += ("iter",) item_handlers["iter"] = GrainCheckpointHandler() @@ -798,9 +776,7 @@ def map_to_pspec(data): checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager), ): - checkpoint_path = str( - checkpoint_manager.directory / str(step) / "items" - ) + checkpoint_path = str(checkpoint_manager.directory / str(step) / "items") restored_nnx = _load_linen_checkpoint_into_nnx( checkpoint_path, abstract_unboxed_pre_state, @@ -831,9 +807,7 @@ def map_to_pspec(data): (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager), ): return ( - checkpoint_manager.restore( - step, args=Composite(state=checkpoint_args) - ).state, + checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, None, ) # Case 2: Matches if dataset type is "grain" and the data iterator is not a diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 2c089aee3a..22284d2ce6 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -3031,6 +3031,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de if self.eval_interval > 0 and not self.grain_eval_files: raise ValueError("Please specify grain_eval_files or set eval_interval to <=0.") elif self.dataset_type == DatasetType.TFDS: + logger.warning( + "tfds pipeline is deprecated. Use dataset_type=grain, grain_file_type=tfrecord, and provide grain_train_files." + ) if not self.dataset_name: raise ValueError("dataset_name can't be empty when dataset_type=tfds") if self.eval_interval > 0 and not self.eval_split: diff --git a/src/maxtext/input_pipeline/input_pipeline_interface.py b/src/maxtext/input_pipeline/input_pipeline_interface.py index 5229a1729c..4a0b37532f 100644 --- a/src/maxtext/input_pipeline/input_pipeline_interface.py +++ b/src/maxtext/input_pipeline/input_pipeline_interface.py @@ -25,10 +25,6 @@ from maxtext.input_pipeline.hf_data_processing import make_hf_eval_iterator from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_train_iterator from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_eval_iterator -from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator -from maxtext.input_pipeline.tfds_data_processing import make_tfds_eval_iterator -from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator -from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_eval_iterator from maxtext.input_pipeline.synthetic_data_processing import SyntheticDataIterator from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator from maxtext.utils import max_logging @@ -71,12 +67,16 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh): eval_iterator = SyntheticDataIterator(config, mesh) if config.eval_interval > 0 else None return SyntheticDataIterator(config, mesh), eval_iterator dataset_type_to_train_eval_iterator = { - "tfds": (make_tfds_train_iterator, make_tfds_eval_iterator), "grain": (make_grain_train_iterator, make_grain_eval_iterator), "hf": (make_hf_train_iterator, make_hf_eval_iterator), - "c4_mlperf": (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator), "olmo_grain": (make_olmo_grain_train_iterator, make_olmo_grain_eval_iterator), } + if config.dataset_type in ("tfds", "c4_mlperf"): + from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator, make_tfds_eval_iterator # pylint: disable=import-outside-toplevel + from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator # pylint: disable=import-outside-toplevel + + dataset_type_to_train_eval_iterator["tfds"] = (make_tfds_train_iterator, make_tfds_eval_iterator) + dataset_type_to_train_eval_iterator["c4_mlperf"] = (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator) # Collect train and eval iterators if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf", "olmo_grain"]: diff --git a/src/maxtext/input_pipeline/multihost_dataloading.py b/src/maxtext/input_pipeline/multihost_dataloading.py index 7757c27313..221a6ed338 100644 --- a/src/maxtext/input_pipeline/multihost_dataloading.py +++ b/src/maxtext/input_pipeline/multihost_dataloading.py @@ -25,7 +25,14 @@ import json from etils import epath -import tensorflow as tf # pylint: disable=g-import-not-at-top + +try: + import tensorflow as tf + + _TF_RETRYABLE_ERRORS = (tf.errors.FailedPreconditionError,) +except ImportError: + tf = None # type: ignore[assignment] + _TF_RETRYABLE_ERRORS = () import numpy as np @@ -74,14 +81,14 @@ class MultiHostDataLoadIterator: def __init__( self, - dataloader: tf.data.Dataset | Iterable, + dataloader: Iterable, global_mesh: Mesh, generate_padding_batch: bool = False, expansion_loading_factor_for_grain: int = -1, ): self.global_mesh = global_mesh self.dataloader = dataloader - if isinstance(self.dataloader, tf.data.Dataset): + if hasattr(self.dataloader, "as_numpy_iterator"): self.local_iterator = self.dataloader.as_numpy_iterator() elif isinstance(self.dataloader, Iterable): self.local_iterator = iter(self.dataloader) @@ -93,7 +100,7 @@ def __init__( self.expansion_loading_factor_for_grain = expansion_loading_factor_for_grain def reset(self): - if isinstance(self.dataloader, tf.data.Dataset): + if hasattr(self.dataloader, "as_numpy_iterator"): self.local_iterator = self.dataloader.as_numpy_iterator() elif isinstance(self.dataloader, Iterable): self.local_iterator = iter(self.dataloader) @@ -132,7 +139,7 @@ def _get_next_batch_sharded(self) -> jax.Array: local_data_list.append(next_batch) local_data = jtu.tree_map(lambda *xs: np.concatenate(xs, axis=0), *local_data_list) break # exit the loop on success - except tf.errors.FailedPreconditionError as e: + except _TF_RETRYABLE_ERRORS as e: max_logging.log(f"Failed to get next data batch due to {e}, retrying") time.sleep(SLEEP_TIME) except StopIteration as e: @@ -188,7 +195,7 @@ def __init__(self, get_ds_fn, preprocessing_fn, global_shape, checkpoint_path, e def reset(self): ds = self.get_ds_fn(dataloading_host_index=jax.process_index(), dataloading_host_count=jax.process_count()) dataloader = self.preprocessing_fn(dataset=ds) - if isinstance(dataloader, tf.data.Dataset): + if hasattr(dataloader, "as_numpy_iterator"): self.iterator = dataloader.as_numpy_iterator() elif isinstance(dataloader, Iterable): self.iterator = iter(dataloader) diff --git a/tests/end_to_end/tpu/test_convergence_1b_params.sh b/tests/end_to_end/tpu/test_convergence_1b_params.sh index 924ee2f8a0..28a7e0e22a 100644 --- a/tests/end_to_end/tpu/test_convergence_1b_params.sh +++ b/tests/end_to_end/tpu/test_convergence_1b_params.sh @@ -18,7 +18,7 @@ export LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass. export STEPS=20400 # Run for 20B tokens for a 1B sized mode for "chinchilla" scaling https://arxiv.org/abs/2203.15556 export EVAL_STEPS=160 export EVAL_INTERVAL=100 -export DATASET_TYPE=tfds +export DATASET_TYPE=grain export MTP_NUM_LAYERS=0 # Disable MTP by default export PER_DEVICE_BATCH_SIZE=8.0 # With the default learning rate (3e-4) this should have global batch of 512, with 2k sequence length (1M global batch in tokens) diff --git a/tests/integration/checkpoint_compatibility_test.py b/tests/integration/checkpoint_compatibility_test.py index 5b628695b5..27febcbd9e 100644 --- a/tests/integration/checkpoint_compatibility_test.py +++ b/tests/integration/checkpoint_compatibility_test.py @@ -26,10 +26,16 @@ """ from datetime import datetime +import importlib.util import json import os import pytest from maxtext.trainers.pre_train.train import main as train_main + +pytestmark = pytest.mark.skipif( + importlib.util.find_spec("tensorflow") is None, + reason="tensorflow not installed; skip testing checkpoint compatibility between tfds and grain", +) from maxtext.utils.globals import MAXTEXT_REPO_ROOT from tests.integration.checkpointing_test import get_checkpointing_command diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 3cd9c5d4d2..745e24002e 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -309,7 +309,7 @@ def fake_to_nnx(*args, **kwargs): # pylint: disable=unused-argument context_parallel_strategy="ring", context_parallel_load_balance=False, packing=True, - dataset_type="tfds", + dataset_type="grain", max_segments_per_seq=4, head_dim=2, attention_kernel="cudnn_flash_te", diff --git a/tests/unit/multihost_dataloading_test.py b/tests/unit/multihost_dataloading_test.py index 1899cc0b49..d4a6172141 100644 --- a/tests/unit/multihost_dataloading_test.py +++ b/tests/unit/multihost_dataloading_test.py @@ -13,6 +13,7 @@ # limitations under the License. # pylint: disable=missing-module-docstring, missing-function-docstring +import itertools import sys import unittest @@ -23,9 +24,6 @@ import jax from jax.sharding import Mesh from jax.experimental import mesh_utils -from jax.sharding import PartitionSpec - -import tensorflow as tf from maxtext.configs import pyconfig from maxtext.input_pipeline import multihost_dataloading @@ -51,16 +49,14 @@ def setUp(self): data_sharding=["data"], enable_checkpointing=False, ) - global_data_shape = PartitionSpec(batch_size, config.max_target_length) mesh_shape_1d = (len(jax.devices()),) self.mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes) - # creating 2 batches of data - global_data = np.arange(np.prod(global_data_shape) * 2).reshape((batch_size * 2, config.max_target_length)) - - dataset = tf.data.Dataset.from_tensor_slices(global_data) - dataset = dataset.repeat() - dataset = dataset.batch(batch_size) - self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh) + # Create 2 distinct batches and cycle through them infinitely. + global_data = np.arange(batch_size * 2 * config.max_target_length, dtype=np.int32).reshape( + (batch_size * 2, config.max_target_length) + ) + data_batches = [global_data[:batch_size], global_data[batch_size:]] + self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(itertools.cycle(data_batches), self.mesh) @pytest.mark.tpu_only def test_batch_sharded_data_pipeline(self): diff --git a/tests/unit/tfds_data_processing_test.py b/tests/unit/tfds_data_processing_test.py index 6f3beaaf6f..fe1bb75f24 100644 --- a/tests/unit/tfds_data_processing_test.py +++ b/tests/unit/tfds_data_processing_test.py @@ -17,12 +17,14 @@ import sys import unittest +import pytest + import jax from jax.sharding import Mesh from jax.experimental import mesh_utils -import tensorflow as tf -import tensorflow_datasets as tfds +tf = pytest.importorskip("tensorflow") +tfds = pytest.importorskip("tensorflow_datasets") from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT diff --git a/tests/unit/train_utils_test.py b/tests/unit/train_utils_test.py index a8b9458794..cd7fb029e2 100644 --- a/tests/unit/train_utils_test.py +++ b/tests/unit/train_utils_test.py @@ -32,7 +32,7 @@ class MockConfig: quantization: str = "" gradient_accumulation_steps: int = 1 packing: bool = False - dataset_type: str = "tfds" + dataset_type: str = "synthetic" # Fields needed for create_training_optimizer opt_type: str = "adamw"