Skip to content

Commit 4b82901

Browse files
committed
Switch from pytest-split to pytest-xdist for parallel test execution
Previously, CPU tests were distributed across multiple runners using pytest-split, which assigns the same number of tests to each worker. However, since tests have different runtimes, some workers finish quickly and stand idle while others take a long time, so the workers are not fully utilized. This change adds pytest-xdist so that, within each runner, tests are assigned to worker processes dynamically (pytest-split is still used to split tests statically across runners).
1 parent c2574ab commit 4b82901

3 files changed

Lines changed: 12 additions & 4 deletions

File tree

.github/workflows/build_and_test_maxtext.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
fail-fast: false # don't cancel all jobs on failure
5252
matrix:
5353
image_type: ["py312"]
54-
worker_group: [1, 2, 3, 4]
54+
worker_group: [1, 2]
5555
with:
5656
device_type: cpu
5757
device_name: X64
@@ -63,7 +63,7 @@ jobs:
6363
container_resource_option: "--privileged"
6464
is_scheduled_run: ${{ github.event_name == 'schedule' }}
6565
worker_group: ${{ matrix.worker_group }}
66-
total_workers: 4
66+
total_workers: 2
6767

6868
maxtext_tpu_unit_tests:
6969
needs: build_and_upload_maxtext_package

.github/workflows/run_tests_against_package.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ jobs:
7171
TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }}
7272
TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }}
7373
MAXTEXT_PACKAGE_EXTRA: ${{ inputs.device_type == 'cpu' && 'tpu' || inputs.device_type }}
74+
ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency)
7475
options: ${{ inputs.container_resource_option }}
7576
steps:
7677
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -107,6 +108,7 @@ jobs:
107108
if [ "${{ inputs.device_type }}" != "cuda12" ]; then
108109
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
109110
fi
111+
# Use pytest-split to statically split tests across runners, and pytest-xdist to dynamically split across processes within each runner
112+
[ "${{ inputs.total_workers }}" -gt 1 ] && .venv/bin/python3 -m pip install --quiet pytest-split pytest-xdist && SPLIT_ARGS="--splits ${{ inputs.total_workers }} --group ${{ inputs.worker_group }} -n auto" || SPLIT_ARGS=""
110113
# TODO: Fix the skipped tests and remove the deselect flags
111-
[ "${{ inputs.total_workers }}" -gt 1 ] && .venv/bin/python3 -m pip install --quiet pytest-split && SPLIT_ARGS="--splits ${{ inputs.total_workers }} --group ${{ inputs.worker_group }}" || SPLIT_ARGS=""
112114
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 --deselect "tests/tokenizer_test.py::TokenizerTest::test_detokenize" $SPLIT_ARGS

tests/grain_data_processing_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def setUp(self):
6868
)
6969
self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
7070

71+
@pytest.mark.cpu_only
7172
def test_train_ds(self):
7273
expected_shape = [jax.device_count(), self.config.max_target_length]
7374
# For training we pack multiple short examples in one example.
@@ -84,7 +85,7 @@ def test_train_ds(self):
8485
"targets_segmentation": expected_shape,
8586
},
8687
)
87-
88+
@pytest.mark.cpu_only
8889
def test_batch_determinism(self):
8990
batch1 = next(self.train_iter)
9091
train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
@@ -96,6 +97,7 @@ def test_batch_determinism(self):
9697
self.assertTrue((batch1["inputs_position"] == batch2["inputs_position"]).all())
9798
self.assertTrue((batch1["targets_position"] == batch2["targets_position"]).all())
9899

100+
@pytest.mark.cpu_only
99101
def test_for_loop_repeatable(self):
100102
def get_first_batch(iterator):
101103
batch = None
@@ -223,6 +225,7 @@ def setUp(self):
223225
"and it affects batch determinism at first."
224226
)
225227
)
228+
@pytest.mark.cpu_only
226229
def test_batch_determinism(self):
227230
super().test_batch_determinism()
228231

@@ -264,6 +267,7 @@ def setUp(self):
264267
)
265268
self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
266269

270+
@pytest.mark.cpu_only
267271
def test_train_ds(self):
268272
expected_shape = [jax.device_count(), self.config.max_target_length]
269273
# For training we pack multiple short examples in one example.
@@ -281,6 +285,7 @@ def test_train_ds(self):
281285
},
282286
)
283287

288+
@pytest.mark.cpu_only
284289
def test_batch_determinism(self):
285290
batch1 = next(self.train_iter)
286291
train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
@@ -292,6 +297,7 @@ def test_batch_determinism(self):
292297
self.assertTrue((batch1["inputs_position"] == batch2["inputs_position"]).all())
293298
self.assertTrue((batch1["targets_position"] == batch2["targets_position"]).all())
294299

300+
@pytest.mark.cpu_only
295301
def test_for_loop_repeatable(self):
296302
def get_first_batch(iterator):
297303
batch = None

0 commit comments

Comments
 (0)