Skip to content

Commit 9b31a12

Browse files
committed
Speed up cpu-unit CI
1 parent 1907615 commit 9b31a12

4 files changed

Lines changed: 54 additions & 34 deletions

File tree

.github/workflows/run_pathways_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ jobs:
102102
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets
103103
export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext
104104
# TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported.
105-
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not CompileThenLoad" --durations=0
105+
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not AotJaxprIdenticalTest and not CompileThenLoad and not test_diloco_two_slices" --durations=0
106106
env:
107107
PYTHONPATH: "${{ github.workspace }}/src"
108108
services:

.github/workflows/run_tests_against_package.yml

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,11 @@ name: Run Tests Against MaxText Package
1919
on:
2020
workflow_call:
2121
inputs:
22+
flavor:
23+
description: 'Test flavor name (e.g. cpu-unit, tpu-unit) - used for artifact naming'
24+
required: false
25+
type: string
26+
default: ''
2227
device_type:
2328
required: true
2429
type: string
@@ -164,6 +169,11 @@ jobs:
164169
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
165170
$PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist
166171
SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto"
172+
# On scheduled runs, record per-shard durations so future splits balance by time (LPT).
173+
# Merge artifacts offline and commit as .test_durations at repo root.
174+
if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then
175+
SPLIT_ARGS="${SPLIT_ARGS} --store-durations --durations-path=.test_durations.${INPUTS_WORKER_GROUP}"
176+
fi
167177
else
168178
SPLIT_ARGS=""
169179
fi
@@ -195,3 +205,11 @@ jobs:
195205
# If scheduled, upload to scheduled flag only. If PR, upload to regular flag only.
196206
flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }}
197207
verbose: true
208+
- name: Upload test durations artifact
209+
if: ${{ inputs.is_scheduled_run == 'true' && inputs.total_workers > 1 }}
210+
uses: actions/upload-artifact@v4
211+
continue-on-error: true
212+
with:
213+
name: test-durations-${{ inputs.flavor }}-${{ inputs.worker_group }}
214+
path: .test_durations.*
215+
if-no-files-found: ignore

.github/workflows/run_tests_coordinator.yml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,10 +57,11 @@ jobs:
5757
strategy:
5858
fail-fast: false
5959
matrix:
60-
worker_group: ${{ fromJSON(contains(inputs.flavor, 'cpu-unit') && '[1, 2]' || '[1]') }}
60+
worker_group: ${{ fromJSON(contains(inputs.flavor, 'cpu-unit') && '[1, 2, 3, 4]' || '[1]') }}
6161

6262
uses: ./.github/workflows/run_tests_against_package.yml
6363
with:
64+
flavor: ${{ inputs.flavor }}
6465
# Infrastructure Mapping
6566
device_type: >-
6667
${{ fromJSON('{
@@ -148,5 +149,5 @@ jobs:
148149
is_scheduled_run: ${{ inputs.is_scheduled_run }}
149150
maxtext_installed: ${{ inputs.maxtext_installed }}
150151
worker_group: ${{ matrix.worker_group }}
151-
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }}
152+
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 4 || 1 }}
152153
maxtext_sha: ${{ inputs.maxtext_sha }}

tests/unit/train_compile_test.py

Lines changed: 32 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -373,6 +373,8 @@ def test_moe_dropping_bf16(self):
373373
"use_iota_embed=true",
374374
"compile_topology_num_slices=1",
375375
"model_name=mixtral-8x7b",
376+
"override_model_config=true",
377+
"base_num_decoder_layers=8",
376378
"sparse_matmul=False",
377379
"capacity_factor=1",
378380
"per_device_batch_size=4",
@@ -442,6 +444,8 @@ def test_moe_megablox_ring_ep_random(self):
442444
"use_iota_embed=true",
443445
"compile_topology_num_slices=1",
444446
"model_name=deepseek3-test",
447+
"override_model_config=true",
448+
"base_num_decoder_layers=8",
445449
"sparse_matmul=True",
446450
"megablox=True",
447451
"per_device_batch_size=4",
@@ -466,6 +470,8 @@ def test_moe_ragged_dot_bf16(self):
466470
"use_iota_embed=true",
467471
"compile_topology_num_slices=1",
468472
"model_name=mixtral-8x7b",
473+
"override_model_config=true",
474+
"base_num_decoder_layers=8",
469475
"sparse_matmul=True",
470476
"megablox=False",
471477
"per_device_batch_size=4",
@@ -488,6 +494,8 @@ def test_moe_dense_bf16(self):
488494
"use_iota_embed=true",
489495
"compile_topology_num_slices=1",
490496
"model_name=mixtral-8x7b",
497+
"override_model_config=true",
498+
"base_num_decoder_layers=8",
491499
"sparse_matmul=False",
492500
"capacity_factor=-1",
493501
"per_device_batch_size=4",
@@ -606,6 +614,8 @@ def test_moe_deepseek_with_device_limit(self):
606614
"use_iota_embed=true",
607615
"compile_topology_num_slices=1",
608616
"model_name=deepseek3-test",
617+
"override_model_config=true",
618+
"base_num_decoder_layers=8",
609619
"sparse_matmul=True",
610620
"megablox=False",
611621
"per_device_batch_size=1",
@@ -626,8 +636,8 @@ def test_moe_deepseek_pipeline_subset(self):
626636
"",
627637
get_test_config_path(),
628638
f"compiled_trainstep_file={compiled_trainstep_file}",
629-
"compile_topology=v5p-64",
630-
"compile_topology_num_slices=8",
639+
"compile_topology=v5p-8",
640+
"compile_topology_num_slices=2",
631641
"use_iota_embed=true",
632642
"model_name=deepseek3-test",
633643
"megablox=True",
@@ -636,8 +646,9 @@ def test_moe_deepseek_pipeline_subset(self):
636646
"per_device_batch_size=1",
637647
"max_target_length=1024",
638648
"pipeline_parallel_layers=56",
639-
"ici_expert_parallelism=16",
640-
"dcn_pipeline_parallelism=8",
649+
"ici_expert_parallelism=4",
650+
"ici_fsdp_parallelism=1",
651+
"dcn_pipeline_parallelism=2",
641652
)
642653
)
643654

@@ -669,22 +680,23 @@ def test_moe_llama4_17b_16e(self):
669680
"",
670681
get_test_config_path(),
671682
f"compiled_trainstep_file={compiled_trainstep_file}",
672-
"compile_topology=v5p-128",
683+
"compile_topology=v5p-16",
673684
"compile_topology_num_slices=1",
674685
"model_name=llama4-17b-16e",
686+
"override_model_config=true",
687+
"base_num_decoder_layers=4",
675688
"per_device_batch_size=1",
676689
"max_target_length=1024",
677690
"dtype=bfloat16",
678691
"weight_dtype=bfloat16",
679692
"scan_layers=True",
680-
"ici_fsdp_parallelism=16",
681-
"ici_tensor_parallelism=4",
693+
"ici_fsdp_parallelism=4",
694+
"ici_tensor_parallelism=2",
682695
)
683696
)
684697

685-
@pytest.mark.cpu_only
686-
def test_moe_gpt_oss_20b_sparse_matmul(self):
687-
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"
698+
def _run_moe_gpt_oss_20b(self, suffix, matmul_args):
699+
compiled_trainstep_file = f"/tmp/test_moe_gpt_oss_20b_{suffix}.pickle"
688700
train_compile_main(
689701
(
690702
"",
@@ -693,38 +705,25 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
693705
"compile_topology=v5p-16",
694706
"compile_topology_num_slices=1",
695707
"model_name=gpt-oss-20b",
708+
"override_model_config=true",
709+
"base_num_decoder_layers=8",
696710
"per_device_batch_size=1",
697711
"max_target_length=1024",
698712
"dtype=bfloat16",
699713
"weight_dtype=bfloat16",
700714
"scan_layers=True",
701-
"sparse_matmul=True",
702-
"megablox=True",
703715
"attention=flash",
716+
*matmul_args,
704717
)
705718
)
706719

720+
@pytest.mark.cpu_only
721+
def test_moe_gpt_oss_20b_sparse_matmul(self):
722+
self._run_moe_gpt_oss_20b("sparse_matmul", ["sparse_matmul=True", "megablox=True"])
723+
707724
@pytest.mark.cpu_only
708725
def test_moe_gpt_oss_20b_dense_matmul(self):
709-
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_dense_matmul.pickle"
710-
train_compile_main(
711-
(
712-
"",
713-
get_test_config_path(),
714-
f"compiled_trainstep_file={compiled_trainstep_file}",
715-
"compile_topology=v5p-16",
716-
"compile_topology_num_slices=1",
717-
"model_name=gpt-oss-20b",
718-
"per_device_batch_size=1",
719-
"max_target_length=1024",
720-
"dtype=bfloat16",
721-
"weight_dtype=bfloat16",
722-
"scan_layers=True",
723-
"sparse_matmul=False",
724-
"capacity_factor=-1",
725-
"attention=flash",
726-
)
727-
)
726+
self._run_moe_gpt_oss_20b("dense_matmul", ["sparse_matmul=False", "capacity_factor=-1"])
728727

729728
@pytest.mark.cpu_only
730729
def test_gpt3_6b(self):
@@ -867,6 +866,8 @@ def test_olmo3_7b(self):
867866
"compile_topology=v5p-8",
868867
"compile_topology_num_slices=1",
869868
"model_name=olmo3-7b",
869+
"override_model_config=true",
870+
"base_num_decoder_layers=8",
870871
"per_device_batch_size=1",
871872
"scan_layers=True",
872873
"max_target_length=1024",

0 commit comments

Comments (0)