Skip to content

Commit 065c3be

Browse files
committed
migrate remaining tests off tf
1 parent 3e71df7 commit 065c3be

9 files changed

Lines changed: 40 additions & 30 deletions

src/maxtext/common/checkpointing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def create_orbax_checkpoint_manager(
236236
enable_checkpointing: bool,
237237
use_async: bool,
238238
save_interval_steps: int,
239-
dataset_type: None | str = "tfds",
239+
dataset_type: None | str = None,
240240
orbax_logger: Any = None, # pytype: disable=attribute-error
241241
use_ocdbt: bool = True,
242242
use_zarr3: bool = True,
@@ -269,7 +269,7 @@ def create_orbax_checkpoint_manager(
269269
)
270270
}
271271

272-
if dataset_type == "grain":
272+
if dataset_type is not None and dataset_type == "grain":
273273
item_names += ("iter",)
274274
item_handlers["iter"] = GrainCheckpointHandler()
275275

src/maxtext/input_pipeline/input_pipeline_interface.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@
2525
from maxtext.input_pipeline.hf_data_processing import make_hf_eval_iterator
2626
from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_train_iterator
2727
from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_eval_iterator
28-
from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator
29-
from maxtext.input_pipeline.tfds_data_processing import make_tfds_eval_iterator
30-
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator
31-
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_eval_iterator
3228
from maxtext.input_pipeline.synthetic_data_processing import SyntheticDataIterator
3329
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
3430
from maxtext.utils import max_logging
@@ -71,12 +67,15 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
7167
eval_iterator = SyntheticDataIterator(config, mesh) if config.eval_interval > 0 else None
7268
return SyntheticDataIterator(config, mesh), eval_iterator
7369
dataset_type_to_train_eval_iterator = {
74-
"tfds": (make_tfds_train_iterator, make_tfds_eval_iterator),
7570
"grain": (make_grain_train_iterator, make_grain_eval_iterator),
7671
"hf": (make_hf_train_iterator, make_hf_eval_iterator),
77-
"c4_mlperf": (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator),
7872
"olmo_grain": (make_olmo_grain_train_iterator, make_olmo_grain_eval_iterator),
7973
}
74+
if config.dataset_type in ("tfds", "c4_mlperf"):
75+
from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator, make_tfds_eval_iterator # pylint: disable=import-outside-toplevel
76+
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator # pylint: disable=import-outside-toplevel
77+
dataset_type_to_train_eval_iterator["tfds"] = (make_tfds_train_iterator, make_tfds_eval_iterator)
78+
dataset_type_to_train_eval_iterator["c4_mlperf"] = (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator)
8079

8180
# Collect train and eval iterators
8281
if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf", "olmo_grain"]:

src/maxtext/input_pipeline/multihost_dataloading.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
import json
2626

2727
from etils import epath
28-
import tensorflow as tf # pylint: disable=g-import-not-at-top
28+
try:
29+
import tensorflow as tf
30+
_TF_RETRYABLE_ERRORS = (tf.errors.FailedPreconditionError,)
31+
except ImportError:
32+
tf = None # type: ignore[assignment]
33+
_TF_RETRYABLE_ERRORS = ()
2934

3035
import numpy as np
3136

@@ -74,14 +79,14 @@ class MultiHostDataLoadIterator:
7479

7580
def __init__(
7681
self,
77-
dataloader: tf.data.Dataset | Iterable,
82+
dataloader: Iterable,
7883
global_mesh: Mesh,
7984
generate_padding_batch: bool = False,
8085
expansion_loading_factor_for_grain: int = -1,
8186
):
8287
self.global_mesh = global_mesh
8388
self.dataloader = dataloader
84-
if isinstance(self.dataloader, tf.data.Dataset):
89+
if hasattr(self.dataloader, "as_numpy_iterator"):
8590
self.local_iterator = self.dataloader.as_numpy_iterator()
8691
elif isinstance(self.dataloader, Iterable):
8792
self.local_iterator = iter(self.dataloader)
@@ -93,7 +98,7 @@ def __init__(
9398
self.expansion_loading_factor_for_grain = expansion_loading_factor_for_grain
9499

95100
def reset(self):
96-
if isinstance(self.dataloader, tf.data.Dataset):
101+
if hasattr(self.dataloader, "as_numpy_iterator"):
97102
self.local_iterator = self.dataloader.as_numpy_iterator()
98103
elif isinstance(self.dataloader, Iterable):
99104
self.local_iterator = iter(self.dataloader)
@@ -132,7 +137,7 @@ def _get_next_batch_sharded(self) -> jax.Array:
132137
local_data_list.append(next_batch)
133138
local_data = jtu.tree_map(lambda *xs: np.concatenate(xs, axis=0), *local_data_list)
134139
break # exit the loop on success
135-
except tf.errors.FailedPreconditionError as e:
140+
except _TF_RETRYABLE_ERRORS as e:
136141
max_logging.log(f"Failed to get next data batch due to {e}, retrying")
137142
time.sleep(SLEEP_TIME)
138143
except StopIteration as e:
@@ -188,7 +193,7 @@ def __init__(self, get_ds_fn, preprocessing_fn, global_shape, checkpoint_path, e
188193
def reset(self):
189194
ds = self.get_ds_fn(dataloading_host_index=jax.process_index(), dataloading_host_count=jax.process_count())
190195
dataloader = self.preprocessing_fn(dataset=ds)
191-
if isinstance(dataloader, tf.data.Dataset):
196+
if hasattr(dataloader, "as_numpy_iterator"):
192197
self.iterator = dataloader.as_numpy_iterator()
193198
elif isinstance(dataloader, Iterable):
194199
self.iterator = iter(dataloader)

tests/end_to_end/tpu/test_convergence_1b_params.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass.
1818
export STEPS=20400 # Run for 20B tokens for a 1B sized model for "chinchilla" scaling https://arxiv.org/abs/2203.15556
1919
export EVAL_STEPS=160
2020
export EVAL_INTERVAL=100
21-
export DATASET_TYPE=tfds
21+
export DATASET_TYPE=grain
2222
export MTP_NUM_LAYERS=0 # Disable MTP by default
2323
export PER_DEVICE_BATCH_SIZE=8.0 # With the default learning rate (3e-4) this should have global batch of 512, with 2k sequence length (1M global batch in tokens)
2424

tests/integration/checkpoint_compatibility_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@
2626
"""
2727

2828
from datetime import datetime
29+
import importlib.util
2930
import json
3031
import os
3132
import pytest
3233
from maxtext.trainers.pre_train.train import main as train_main
34+
35+
pytestmark = pytest.mark.skipif(
36+
importlib.util.find_spec("tensorflow") is None,
37+
reason="tensorflow not installed; skip testing checkpoint compatibility between tfds and grain",
38+
)
3339
from maxtext.utils.globals import MAXTEXT_REPO_ROOT
3440
from tests.integration.checkpointing_test import get_checkpointing_command
3541

tests/unit/attention_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def fake_to_nnx(*args, **kwargs): # pylint: disable=unused-argument
309309
context_parallel_strategy="ring",
310310
context_parallel_load_balance=False,
311311
packing=True,
312-
dataset_type="tfds",
312+
dataset_type="grain",
313313
max_segments_per_seq=4,
314314
head_dim=2,
315315
attention_kernel="cudnn_flash_te",

tests/unit/multihost_dataloading_test.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=missing-module-docstring, missing-function-docstring
16+
import itertools
1617
import sys
1718
import unittest
1819

@@ -23,9 +24,6 @@
2324
import jax
2425
from jax.sharding import Mesh
2526
from jax.experimental import mesh_utils
26-
from jax.sharding import PartitionSpec
27-
28-
import tensorflow as tf
2927

3028
from maxtext.configs import pyconfig
3129
from maxtext.input_pipeline import multihost_dataloading
@@ -51,16 +49,16 @@ def setUp(self):
5149
data_sharding=["data"],
5250
enable_checkpointing=False,
5351
)
54-
global_data_shape = PartitionSpec(batch_size, config.max_target_length)
5552
mesh_shape_1d = (len(jax.devices()),)
5653
self.mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes)
57-
# creating 2 batches of data
58-
global_data = np.arange(np.prod(global_data_shape) * 2).reshape((batch_size * 2, config.max_target_length))
59-
60-
dataset = tf.data.Dataset.from_tensor_slices(global_data)
61-
dataset = dataset.repeat()
62-
dataset = dataset.batch(batch_size)
63-
self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh)
54+
# Create 2 distinct batches and cycle through them infinitely.
55+
global_data = np.arange(batch_size * 2 * config.max_target_length, dtype=np.int32).reshape(
56+
(batch_size * 2, config.max_target_length)
57+
)
58+
data_batches = [global_data[:batch_size], global_data[batch_size:]]
59+
self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(
60+
itertools.cycle(data_batches), self.mesh
61+
)
6462

6563
@pytest.mark.tpu_only
6664
def test_batch_sharded_data_pipeline(self):

tests/unit/tfds_data_processing_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
import sys
1818
import unittest
1919

20+
import pytest
21+
2022
import jax
2123
from jax.sharding import Mesh
2224
from jax.experimental import mesh_utils
2325

24-
import tensorflow as tf
25-
import tensorflow_datasets as tfds
26+
tf = pytest.importorskip("tensorflow")
27+
tfds = pytest.importorskip("tensorflow_datasets")
2628

2729
from maxtext.configs import pyconfig
2830
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT

tests/unit/train_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class MockConfig:
3232
quantization: str = ""
3333
gradient_accumulation_steps: int = 1
3434
packing: bool = False
35-
dataset_type: str = "tfds"
35+
dataset_type: str = "synthetic"
3636

3737
# Fields needed for create_training_optimizer
3838
opt_type: str = "adamw"

0 commit comments

Comments
 (0)