Skip to content

Commit 8086fc1

Browse files
gagika (Gagik Amirkhanyan)
authored and committed
Speed up cpu-unit CI
1 parent 1907615 commit 8086fc1

4 files changed

Lines changed: 33 additions & 35 deletions

File tree

.github/workflows/run_pathways_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ jobs:
102102
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets
103103
export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext
104104
# TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported.
105-
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not CompileThenLoad" --durations=0
105+
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not AotJaxprIdenticalTest and not CompileThenLoad" --durations=0
106106
env:
107107
PYTHONPATH: "${{ github.workspace }}/src"
108108
services:

.github/workflows/run_tests_against_package.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ jobs:
164164
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
165165
$PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist
166166
SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto"
167+
# On scheduled runs, record per-shard durations so future splits balance by time (LPT).
168+
# Merge artifacts offline and commit as .test_durations at repo root.
169+
if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then
170+
SPLIT_ARGS="${SPLIT_ARGS} --store-durations --durations-path=.test_durations.${INPUTS_WORKER_GROUP}"
171+
fi
167172
else
168173
SPLIT_ARGS=""
169174
fi
@@ -195,3 +200,11 @@ jobs:
195200
# If scheduled, upload to scheduled flag only. If PR, upload to regular flag only.
196201
flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }}
197202
verbose: true
203+
- name: Upload test durations artifact
204+
if: ${{ inputs.is_scheduled_run == 'true' && inputs.total_workers > 1 }}
205+
uses: actions/upload-artifact@v4
206+
continue-on-error: true
207+
with:
208+
name: test-durations-${{ inputs.flavor }}-${{ inputs.worker_group }}
209+
path: .test_durations.*
210+
if-no-files-found: ignore

.github/workflows/run_tests_coordinator.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ jobs:
5757
strategy:
5858
fail-fast: false
5959
matrix:
60-
worker_group: ${{ fromJSON(contains(inputs.flavor, 'cpu-unit') && '[1, 2]' || '[1]') }}
60+
worker_group: ${{ fromJSON(contains(inputs.flavor, 'cpu-unit') && '[1, 2, 3, 4]' || '[1]') }}
6161

6262
uses: ./.github/workflows/run_tests_against_package.yml
6363
with:
@@ -148,5 +148,5 @@ jobs:
148148
is_scheduled_run: ${{ inputs.is_scheduled_run }}
149149
maxtext_installed: ${{ inputs.maxtext_installed }}
150150
worker_group: ${{ matrix.worker_group }}
151-
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }}
151+
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 4 || 1 }}
152152
maxtext_sha: ${{ inputs.maxtext_sha }}

tests/unit/train_compile_test.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,8 @@ def test_moe_deepseek_pipeline_subset(self):
626626
"",
627627
get_test_config_path(),
628628
f"compiled_trainstep_file={compiled_trainstep_file}",
629-
"compile_topology=v5p-64",
630-
"compile_topology_num_slices=8",
629+
"compile_topology=v5p-8",
630+
"compile_topology_num_slices=2",
631631
"use_iota_embed=true",
632632
"model_name=deepseek3-test",
633633
"megablox=True",
@@ -636,8 +636,8 @@ def test_moe_deepseek_pipeline_subset(self):
636636
"per_device_batch_size=1",
637637
"max_target_length=1024",
638638
"pipeline_parallel_layers=56",
639-
"ici_expert_parallelism=16",
640-
"dcn_pipeline_parallelism=8",
639+
"ici_expert_parallelism=8",
640+
"dcn_pipeline_parallelism=2",
641641
)
642642
)
643643

@@ -669,22 +669,22 @@ def test_moe_llama4_17b_16e(self):
669669
"",
670670
get_test_config_path(),
671671
f"compiled_trainstep_file={compiled_trainstep_file}",
672-
"compile_topology=v5p-128",
672+
"compile_topology=v5p-16",
673673
"compile_topology_num_slices=1",
674674
"model_name=llama4-17b-16e",
675675
"per_device_batch_size=1",
676676
"max_target_length=1024",
677677
"dtype=bfloat16",
678678
"weight_dtype=bfloat16",
679679
"scan_layers=True",
680-
"ici_fsdp_parallelism=16",
681-
"ici_tensor_parallelism=4",
680+
"ici_fsdp_parallelism=4",
681+
"ici_tensor_parallelism=2",
682+
"ici_expert_parallelism=2",
682683
)
683684
)
684685

685-
@pytest.mark.cpu_only
686-
def test_moe_gpt_oss_20b_sparse_matmul(self):
687-
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"
686+
def _run_moe_gpt_oss_20b(self, suffix, matmul_args):
687+
compiled_trainstep_file = f"/tmp/test_moe_gpt_oss_20b_{suffix}.pickle"
688688
train_compile_main(
689689
(
690690
"",
@@ -698,33 +698,18 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
698698
"dtype=bfloat16",
699699
"weight_dtype=bfloat16",
700700
"scan_layers=True",
701-
"sparse_matmul=True",
702-
"megablox=True",
703701
"attention=flash",
702+
*matmul_args,
704703
)
705704
)
706705

706+
@pytest.mark.cpu_only
707+
def test_moe_gpt_oss_20b_sparse_matmul(self):
708+
self._run_moe_gpt_oss_20b("sparse_matmul", ["sparse_matmul=True", "megablox=True"])
709+
707710
@pytest.mark.cpu_only
708711
def test_moe_gpt_oss_20b_dense_matmul(self):
709-
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_dense_matmul.pickle"
710-
train_compile_main(
711-
(
712-
"",
713-
get_test_config_path(),
714-
f"compiled_trainstep_file={compiled_trainstep_file}",
715-
"compile_topology=v5p-16",
716-
"compile_topology_num_slices=1",
717-
"model_name=gpt-oss-20b",
718-
"per_device_batch_size=1",
719-
"max_target_length=1024",
720-
"dtype=bfloat16",
721-
"weight_dtype=bfloat16",
722-
"scan_layers=True",
723-
"sparse_matmul=False",
724-
"capacity_factor=-1",
725-
"attention=flash",
726-
)
727-
)
712+
self._run_moe_gpt_oss_20b("dense_matmul", ["sparse_matmul=False", "capacity_factor=-1"])
728713

729714
@pytest.mark.cpu_only
730715
def test_gpt3_6b(self):
@@ -766,7 +751,7 @@ def test_qwen3_next(self):
766751
"",
767752
get_test_config_path(),
768753
f"compiled_trainstep_file={compiled_trainstep_file}",
769-
"compile_topology=v5p-256",
754+
"compile_topology=v5p-8",
770755
"compile_topology_num_slices=1",
771756
"model_name=qwen3-next-80b-a3b",
772757
"per_device_batch_size=1",

0 commit comments

Comments (0)