-
Notifications
You must be signed in to change notification settings - Fork 542
Expand file tree
/
Copy pathtfds_data_processing_test.py
More file actions
163 lines (140 loc) · 6.11 KB
/
Copy pathtfds_data_processing_test.py
File metadata and controls
163 lines (140 loc) · 6.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-module-docstring, missing-function-docstring
import os
import sys
import unittest
import pytest
import jax
from jax.sharding import Mesh
from jax.experimental import mesh_utils
tf = pytest.importorskip("tensorflow")
tfds = pytest.importorskip("tensorflow_datasets")
from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
from maxtext.input_pipeline import tfds_data_processing
from maxtext.input_pipeline import input_pipeline_interface
from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory
class TfdsDataProcessingTest(unittest.TestCase):
  """Tests for `tfds_data_processing` against a minimal local C4 dataset.

  The raw training dataset and the train/eval iterators are built lazily and
  cached on the class, so each is built at most once per test class and shared
  by all test methods; `tearDownClass` releases them. Tests whose assertions
  depend on reading the *first* batch of a stream build their own iterators, so
  their outcome does not depend on which other tests already consumed the
  shared ones.
  """

  # Class attributes used as the shared, lazily-built cache.
  _CACHE_ATTRS = ("_cached_train_ds", "_cached_train_iter", "_cached_eval_iter")
  # Keys of a packed training batch that must match between two fresh pipelines.
  _BATCH_KEYS = (
      "inputs",
      "targets",
      "inputs_segmentation",
      "targets_segmentation",
      "inputs_position",
      "targets_position",
  )

  def setUp(self):
    super().setUp()
    dataset_path = os.path.join("tests", "assets", "local_datasets", "c4_en_dataset_minimal")
    base_output_directory = get_test_base_output_directory(cloud_path="gs://max-experiments/")
    config_kwargs = {
        "per_device_batch_size": 1,
        "run_name": "test",
        "mesh_axes": ["data"],
        "logical_axis_rules": [["batch", "data"]],
        "data_sharding": ["data"],
        "base_output_directory": base_output_directory,
        "dataset_path": dataset_path,
        "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
        "enable_checkpointing": False,
        "eval_interval": 10,
        "max_target_length": 128,
    }
    config = pyconfig.initialize([sys.argv[0], get_test_config_path()], **config_kwargs)
    # Point TFDS_DATA_DIR at the local dataset (presumably consumed by the
    # tfds_data_processing iterator builders, which are not passed a data_dir
    # here). Scope the override to this test and restore the previous value
    # afterwards so it does not leak into unrelated tests in the same process.
    self._set_env_for_test("TFDS_DATA_DIR", config.dataset_path)
    self.config = config
    self.mesh_shape_1d = (len(jax.devices()),)
    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
        self.config.data_sharding,
        self.config.global_batch_size_to_load,
        self.config.global_batch_size_to_train_on,
        self.config.max_target_length,
        self.mesh,
    )
    # Fixed shuffle seed, plus a per-example `tfds_id` so tests can compare
    # exactly which examples were read.
    self.read_config = tfds.ReadConfig(
        shuffle_seed=self.config.data_shuffle_seed,
    )
    self.read_config.add_tfds_id = True

  @classmethod
  def tearDownClass(cls):
    # Release the shared dataset/iterators so they do not outlive this class.
    for attr in cls._CACHE_ATTRS:
      if attr in vars(cls):
        delattr(cls, attr)
    super().tearDownClass()

  def _set_env_for_test(self, name, value):
    """Sets environment variable `name` to `value` until the current test ends."""
    previous = os.environ.get(name)

    def restore():
      if previous is None:
        os.environ.pop(name, None)
      else:
        os.environ[name] = previous

    os.environ[name] = value
    self.addCleanup(restore)

  @property
  def train_ds(self):
    """Training split from `_get_datasets`, built once per class and shared."""
    # pylint: disable=protected-access
    cls = self.__class__
    if not hasattr(cls, "_cached_train_ds"):
      cls._cached_train_ds = self._get_datasets()
    return cls._cached_train_ds

  @property
  def train_iter(self):
    """Training iterator built once per class and shared (its position is shared too)."""
    # pylint: disable=protected-access
    cls = self.__class__
    if not hasattr(cls, "_cached_train_iter"):
      cls._cached_train_iter = self._make_train_iterator()
    return cls._cached_train_iter

  @property
  def eval_iter(self):
    """Eval iterator built once per class and shared."""
    # pylint: disable=protected-access
    cls = self.__class__
    if not hasattr(cls, "_cached_eval_iter"):
      cls._cached_eval_iter = tfds_data_processing.make_tfds_eval_iterator(
          self.config, self.mesh, self.process_indices
      )
    return cls._cached_eval_iter

  def _make_train_iterator(self):
    """Builds a fresh training iterator from this test's config, mesh and process indices."""
    return tfds_data_processing.make_tfds_train_iterator(self.config, self.mesh, self.process_indices)

  def _get_datasets(self):
    """Builds the raw "train" split of `config.dataset_name` from `config.dataset_path`.

    Each JAX process reads its own input pipeline (one pipeline per process),
    using `self.read_config` (fixed shuffle seed, `tfds_id` attached).
    """
    ds_builder = tfds.builder(self.config.dataset_name, data_dir=self.config.dataset_path)
    self.read_config.input_context = tf.distribute.InputContext(
        input_pipeline_id=jax.process_index(),
        num_input_pipelines=jax.process_count(),
    )
    ds = ds_builder.as_dataset(
        split="train", read_config=self.read_config, shuffle_files=self.config.enable_data_shuffling
    )
    return ds

  def test_train_ds(self):
    expected_shape = [jax.device_count(), self.config.max_target_length]
    # For training we pack multiple short examples in one example.
    # *_position and *_segmentation indicate the boundaries.
    batch = next(self.train_iter)
    self.assertEqual(
        {k: list(v.shape) for k, v in batch.items()},
        {
            "inputs": expected_shape,
            "inputs_position": expected_shape,
            "inputs_segmentation": expected_shape,
            "targets": expected_shape,
            "targets_position": expected_shape,
            "targets_segmentation": expected_shape,
        },
    )

  def test_ds_determinism(self):
    # The first 64 examples of the shared dataset and of a freshly rebuilt one
    # must be the same set of examples (order is not compared).
    first = next(self.train_ds.batch(64).as_numpy_iterator())
    second = next(self._get_datasets().batch(64).as_numpy_iterator())
    self.assertCountEqual(first["tfds_id"], second["tfds_id"])

  def test_batch_determinism(self):
    # Use two independent iterators instead of the shared cached one: the cached
    # iterator may already have been advanced by another test, which would make
    # this comparison depend on test execution order.
    batch1 = next(self._make_train_iterator())
    batch2 = next(self._make_train_iterator())
    for key in self._BATCH_KEYS:
      with self.subTest(key=key):
        self.assertTrue(tf.reduce_all(tf.equal(batch1[key], batch2[key])))

  def test_for_loop_repeatable(self):
    def get_first_batch(iterator):
      # Returns None when the iterator yields nothing.
      for batch in iterator:
        return batch
      return None

    # Iterating the same eval iterator twice with a `for` loop must start from
    # the same batch both times.
    eval_batch1 = get_first_batch(self.eval_iter)
    eval_batch2 = get_first_batch(self.eval_iter)
    # Fail with a clear message (not a TypeError on None) if eval data is empty.
    self.assertIsNotNone(eval_batch1)
    self.assertIsNotNone(eval_batch2)
    self.assertTrue((eval_batch1["inputs"] == eval_batch2["inputs"]).all())
    self.assertTrue((eval_batch1["targets"] == eval_batch2["targets"]).all())
if __name__ == "__main__":
  # Run this file directly with the standard-library runner (outside pytest).
  unittest.main(module="__main__")