Skip to content

Commit 7c68a9d

Browse files
Merge pull request #3957 from AI-Hypercomputer:aireen/rm_tf_test
PiperOrigin-RevId: 925045380
2 parents bfcaf14 + 7da371c commit 7c68a9d

10 files changed

Lines changed: 55 additions & 67 deletions

src/maxtext/common/checkpointing.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,7 @@ def _default_for_sds(sds):
182182
def _make():
183183
if "key" in str(sds.dtype):
184184
base = jax.random.key(0)
185-
return (
186-
base
187-
if sds.shape == ()
188-
else jax.random.split(base, int(np.prod(sds.shape))).reshape(
189-
sds.shape
190-
)
191-
)
185+
return base if sds.shape == () else jax.random.split(base, int(np.prod(sds.shape))).reshape(sds.shape)
192186
return jnp.zeros(sds.shape, dtype=sds.dtype)
193187

194188
sharding = getattr(sds, "sharding", None)
@@ -208,9 +202,7 @@ def _populate_pure_dict_from_partial(abstract_pure, partial_concrete):
208202
return {
209203
k: _populate_pure_dict_from_partial(
210204
v,
211-
partial_concrete.get(k)
212-
if isinstance(partial_concrete, dict)
213-
else None,
205+
partial_concrete.get(k) if isinstance(partial_concrete, dict) else None,
214206
)
215207
for k, v in abstract_pure.items()
216208
}
@@ -243,29 +235,21 @@ def _load_linen_checkpoint_into_nnx(
243235
)
244236
)
245237
restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract)
246-
restored = ocp.args.PyTreeRestore(
247-
item=linen_abstract, restore_args=restore_args, partial_restore=True
248-
)
238+
restored = ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True)
249239
restored = ckptr.restore(epath.Path(path), args=restored)
250240
partial_nnx = train_state_nnx.from_linen_checkpoint_dict(restored)
251241
return _populate_pure_dict_from_partial(nnx_abstract_pure, partial_nnx)
252242

253243

254244
def _rebuild_nnx_with_values(abstract_nnx_state, concrete_weights):
255245
"""Fills each Variable in `abstract_nnx_state` with the matching restored array."""
256-
leaves, treedef = jax.tree_util.tree_flatten(
257-
abstract_nnx_state, is_leaf=lambda x: isinstance(x, nnx.Variable)
258-
)
246+
leaves, treedef = jax.tree_util.tree_flatten(abstract_nnx_state, is_leaf=lambda x: isinstance(x, nnx.Variable))
259247
concrete = jax.tree_util.tree_leaves(concrete_weights)
260248
if len(leaves) != len(concrete):
261249
raise ValueError(
262-
f"Params load leaf-count mismatch: {len(leaves)} abstract Variables vs"
263-
f" {len(concrete)} restored."
250+
f"Params load leaf-count mismatch: {len(leaves)} abstract Variables vs" f" {len(concrete)} restored."
264251
)
265-
new_leaves = [
266-
v.replace(value=a) if isinstance(v, nnx.Variable) else a
267-
for v, a in zip(leaves, concrete)
268-
]
252+
new_leaves = [v.replace(value=a) if isinstance(v, nnx.Variable) else a for v, a in zip(leaves, concrete)]
269253
return jax.tree_util.tree_unflatten(treedef, new_leaves)
270254

271255

@@ -284,9 +268,7 @@ def _load_linen_params_into_nnx(
284268
NNX params Variables.
285269
"""
286270
max_logging.log(f"Restoring Linen-layout params into NNX state at {path}")
287-
linen_abstract = train_state_nnx.to_linen_checkpoint_dict(
288-
{"model": nnx_params_abstract.to_pure_dict()}
289-
)
271+
linen_abstract = train_state_nnx.to_linen_checkpoint_dict({"model": nnx_params_abstract.to_pure_dict()})
290272
ckptr = ocp.Checkpointer(
291273
ocp.PyTreeCheckpointHandler(
292274
restore_concurrent_gb=checkpoint_storage_concurrent_gb,
@@ -298,13 +280,9 @@ def _load_linen_params_into_nnx(
298280
restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract)
299281
restored = ckptr.restore(
300282
epath.Path(path),
301-
args=ocp.args.PyTreeRestore(
302-
item=linen_abstract, restore_args=restore_args, partial_restore=True
303-
),
304-
)
305-
return _rebuild_nnx_with_values(
306-
nnx_params_abstract, restored["params"]["params"]
283+
args=ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True),
307284
)
285+
return _rebuild_nnx_with_values(nnx_params_abstract, restored["params"]["params"])
308286

309287

310288
def _load_full_state_from_path(
@@ -388,7 +366,7 @@ def create_orbax_checkpoint_manager(
388366
enable_checkpointing: bool,
389367
use_async: bool,
390368
save_interval_steps: int,
391-
dataset_type: None | str = "tfds",
369+
dataset_type: None | str = None,
392370
orbax_logger: Any = None, # pytype: disable=attribute-error
393371
use_ocdbt: bool = True,
394372
use_zarr3: bool = True,
@@ -421,7 +399,7 @@ def create_orbax_checkpoint_manager(
421399
)
422400
}
423401

424-
if dataset_type == "grain":
402+
if dataset_type is not None and dataset_type == "grain":
425403
item_names += ("iter",)
426404
item_handlers["iter"] = GrainCheckpointHandler()
427405

@@ -798,9 +776,7 @@ def map_to_pspec(data):
798776
checkpoint_manager,
799777
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
800778
):
801-
checkpoint_path = str(
802-
checkpoint_manager.directory / str(step) / "items"
803-
)
779+
checkpoint_path = str(checkpoint_manager.directory / str(step) / "items")
804780
restored_nnx = _load_linen_checkpoint_into_nnx(
805781
checkpoint_path,
806782
abstract_unboxed_pre_state,
@@ -831,9 +807,7 @@ def map_to_pspec(data):
831807
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
832808
):
833809
return (
834-
checkpoint_manager.restore(
835-
step, args=Composite(state=checkpoint_args)
836-
).state,
810+
checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state,
837811
None,
838812
)
839813
# Case 2: Matches if dataset type is "grain" and the data iterator is not a

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,6 +3040,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
30403040
if self.eval_interval > 0 and not self.grain_eval_files:
30413041
raise ValueError("Please specify grain_eval_files or set eval_interval to <=0.")
30423042
elif self.dataset_type == DatasetType.TFDS:
3043+
logger.warning(
3044+
"tfds pipeline is deprecated. Use dataset_type=grain, grain_file_type=tfrecord, and provide grain_train_files."
3045+
)
30433046
if not self.dataset_name:
30443047
raise ValueError("dataset_name can't be empty when dataset_type=tfds")
30453048
if self.eval_interval > 0 and not self.eval_split:

src/maxtext/input_pipeline/input_pipeline_interface.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@
2525
from maxtext.input_pipeline.hf_data_processing import make_hf_eval_iterator
2626
from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_train_iterator
2727
from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_eval_iterator
28-
from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator
29-
from maxtext.input_pipeline.tfds_data_processing import make_tfds_eval_iterator
30-
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator
31-
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_eval_iterator
3228
from maxtext.input_pipeline.synthetic_data_processing import SyntheticDataIterator
3329
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
3430
from maxtext.utils import max_logging
@@ -71,12 +67,16 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
7167
eval_iterator = SyntheticDataIterator(config, mesh) if config.eval_interval > 0 else None
7268
return SyntheticDataIterator(config, mesh), eval_iterator
7369
dataset_type_to_train_eval_iterator = {
74-
"tfds": (make_tfds_train_iterator, make_tfds_eval_iterator),
7570
"grain": (make_grain_train_iterator, make_grain_eval_iterator),
7671
"hf": (make_hf_train_iterator, make_hf_eval_iterator),
77-
"c4_mlperf": (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator),
7872
"olmo_grain": (make_olmo_grain_train_iterator, make_olmo_grain_eval_iterator),
7973
}
74+
if config.dataset_type in ("tfds", "c4_mlperf"):
75+
from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator, make_tfds_eval_iterator # pylint: disable=import-outside-toplevel
76+
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator # pylint: disable=import-outside-toplevel
77+
78+
dataset_type_to_train_eval_iterator["tfds"] = (make_tfds_train_iterator, make_tfds_eval_iterator)
79+
dataset_type_to_train_eval_iterator["c4_mlperf"] = (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator)
8080

8181
# Collect train and eval iterators
8282
if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf", "olmo_grain"]:

src/maxtext/input_pipeline/multihost_dataloading.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,14 @@
2525
import json
2626

2727
from etils import epath
28-
import tensorflow as tf # pylint: disable=g-import-not-at-top
28+
29+
try:
30+
import tensorflow as tf
31+
32+
_TF_RETRYABLE_ERRORS = (tf.errors.FailedPreconditionError,)
33+
except ImportError:
34+
tf = None # type: ignore[assignment]
35+
_TF_RETRYABLE_ERRORS = ()
2936

3037
import numpy as np
3138

@@ -74,14 +81,14 @@ class MultiHostDataLoadIterator:
7481

7582
def __init__(
7683
self,
77-
dataloader: tf.data.Dataset | Iterable,
84+
dataloader: Iterable,
7885
global_mesh: Mesh,
7986
generate_padding_batch: bool = False,
8087
expansion_loading_factor_for_grain: int = -1,
8188
):
8289
self.global_mesh = global_mesh
8390
self.dataloader = dataloader
84-
if isinstance(self.dataloader, tf.data.Dataset):
91+
if hasattr(self.dataloader, "as_numpy_iterator"):
8592
self.local_iterator = self.dataloader.as_numpy_iterator()
8693
elif isinstance(self.dataloader, Iterable):
8794
self.local_iterator = iter(self.dataloader)
@@ -93,7 +100,7 @@ def __init__(
93100
self.expansion_loading_factor_for_grain = expansion_loading_factor_for_grain
94101

95102
def reset(self):
96-
if isinstance(self.dataloader, tf.data.Dataset):
103+
if hasattr(self.dataloader, "as_numpy_iterator"):
97104
self.local_iterator = self.dataloader.as_numpy_iterator()
98105
elif isinstance(self.dataloader, Iterable):
99106
self.local_iterator = iter(self.dataloader)
@@ -132,7 +139,7 @@ def _get_next_batch_sharded(self) -> jax.Array:
132139
local_data_list.append(next_batch)
133140
local_data = jtu.tree_map(lambda *xs: np.concatenate(xs, axis=0), *local_data_list)
134141
break # exit the loop on success
135-
except tf.errors.FailedPreconditionError as e:
142+
except _TF_RETRYABLE_ERRORS as e:
136143
max_logging.log(f"Failed to get next data batch due to {e}, retrying")
137144
time.sleep(SLEEP_TIME)
138145
except StopIteration as e:
@@ -188,7 +195,7 @@ def __init__(self, get_ds_fn, preprocessing_fn, global_shape, checkpoint_path, e
188195
def reset(self):
189196
ds = self.get_ds_fn(dataloading_host_index=jax.process_index(), dataloading_host_count=jax.process_count())
190197
dataloader = self.preprocessing_fn(dataset=ds)
191-
if isinstance(dataloader, tf.data.Dataset):
198+
if hasattr(dataloader, "as_numpy_iterator"):
192199
self.iterator = dataloader.as_numpy_iterator()
193200
elif isinstance(dataloader, Iterable):
194201
self.iterator = iter(dataloader)

tests/end_to_end/tpu/test_convergence_1b_params.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass.
1818
export STEPS=20400 # Run for 20B tokens for a 1B sized model for "chinchilla" scaling https://arxiv.org/abs/2203.15556
1919
export EVAL_STEPS=160
2020
export EVAL_INTERVAL=100
21-
export DATASET_TYPE=tfds
21+
export DATASET_TYPE=grain
2222
export MTP_NUM_LAYERS=0 # Disable MTP by default
2323
export PER_DEVICE_BATCH_SIZE=8.0 # With the default learning rate (3e-4) this should have global batch of 512, with 2k sequence length (1M global batch in tokens)
2424

tests/integration/checkpoint_compatibility_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@
2626
"""
2727

2828
from datetime import datetime
29+
import importlib.util
2930
import json
3031
import os
3132
import pytest
3233
from maxtext.trainers.pre_train.train import main as train_main
34+
35+
pytestmark = pytest.mark.skipif(
36+
importlib.util.find_spec("tensorflow") is None,
37+
reason="tensorflow not installed; skip testing checkpoint compatibility between tfds and grain",
38+
)
3339
from maxtext.utils.globals import MAXTEXT_REPO_ROOT
3440
from tests.integration.checkpointing_test import get_checkpointing_command
3541

tests/unit/attention_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def fake_to_nnx(*args, **kwargs): # pylint: disable=unused-argument
309309
context_parallel_strategy="ring",
310310
context_parallel_load_balance=False,
311311
packing=True,
312-
dataset_type="tfds",
312+
dataset_type="grain",
313313
max_segments_per_seq=4,
314314
head_dim=2,
315315
attention_kernel="cudnn_flash_te",

tests/unit/multihost_dataloading_test.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=missing-module-docstring, missing-function-docstring
16+
import itertools
1617
import sys
1718
import unittest
1819

@@ -23,9 +24,6 @@
2324
import jax
2425
from jax.sharding import Mesh
2526
from jax.experimental import mesh_utils
26-
from jax.sharding import PartitionSpec
27-
28-
import tensorflow as tf
2927

3028
from maxtext.configs import pyconfig
3129
from maxtext.input_pipeline import multihost_dataloading
@@ -51,16 +49,14 @@ def setUp(self):
5149
data_sharding=["data"],
5250
enable_checkpointing=False,
5351
)
54-
global_data_shape = PartitionSpec(batch_size, config.max_target_length)
5552
mesh_shape_1d = (len(jax.devices()),)
5653
self.mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes)
57-
# creating 2 batches of data
58-
global_data = np.arange(np.prod(global_data_shape) * 2).reshape((batch_size * 2, config.max_target_length))
59-
60-
dataset = tf.data.Dataset.from_tensor_slices(global_data)
61-
dataset = dataset.repeat()
62-
dataset = dataset.batch(batch_size)
63-
self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh)
54+
# Create 2 distinct batches and cycle through them infinitely.
55+
global_data = np.arange(batch_size * 2 * config.max_target_length, dtype=np.int32).reshape(
56+
(batch_size * 2, config.max_target_length)
57+
)
58+
data_batches = [global_data[:batch_size], global_data[batch_size:]]
59+
self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(itertools.cycle(data_batches), self.mesh)
6460

6561
@pytest.mark.tpu_only
6662
def test_batch_sharded_data_pipeline(self):

tests/unit/tfds_data_processing_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
import sys
1818
import unittest
1919

20+
import pytest
21+
2022
import jax
2123
from jax.sharding import Mesh
2224
from jax.experimental import mesh_utils
2325

24-
import tensorflow as tf
25-
import tensorflow_datasets as tfds
26+
tf = pytest.importorskip("tensorflow")
27+
tfds = pytest.importorskip("tensorflow_datasets")
2628

2729
from maxtext.configs import pyconfig
2830
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT

tests/unit/train_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class MockConfig:
3232
quantization: str = ""
3333
gradient_accumulation_steps: int = 1
3434
packing: bool = False
35-
dataset_type: str = "tfds"
35+
dataset_type: str = "synthetic"
3636

3737
# Fields needed for create_training_optimizer
3838
opt_type: str = "adamw"

0 commit comments

Comments
 (0)