Skip to content

Commit 7f5aa98

Browse files
committed
Speed up cpu-unit CI
1 parent 1907615 commit 7f5aa98

4 files changed

Lines changed: 30 additions & 27 deletions

File tree

.github/workflows/run_pathways_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ jobs:
102102
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets
103103
export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext
104104
# TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported.
105-
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not CompileThenLoad" --durations=0
105+
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not AotJaxprIdenticalTest and not CompileThenLoad" --durations=0
106106
env:
107107
PYTHONPATH: "${{ github.workspace }}/src"
108108
services:

.github/workflows/run_tests_against_package.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ name: Run Tests Against MaxText Package
1919
on:
2020
workflow_call:
2121
inputs:
22+
flavor:
23+
description: 'Test flavor name (e.g. cpu-unit, tpu-unit) - used for artifact naming'
24+
required: false
25+
type: string
26+
default: ''
2227
device_type:
2328
required: true
2429
type: string
@@ -164,6 +169,11 @@ jobs:
164169
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
165170
$PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist
166171
SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto"
172+
# On scheduled runs, record per-shard durations so future splits balance by time (LPT).
173+
# Merge artifacts offline and commit as .test_durations at repo root.
174+
if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then
175+
SPLIT_ARGS="${SPLIT_ARGS} --store-durations --durations-path=.test_durations.${INPUTS_WORKER_GROUP}"
176+
fi
167177
else
168178
SPLIT_ARGS=""
169179
fi
@@ -195,3 +205,11 @@ jobs:
195205
# If scheduled, upload to scheduled flag only. If PR, upload to regular flag only.
196206
flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }}
197207
verbose: true
208+
- name: Upload test durations artifact
209+
if: ${{ inputs.is_scheduled_run == 'true' && inputs.total_workers > 1 }}
210+
uses: actions/upload-artifact@v4
211+
continue-on-error: true
212+
with:
213+
name: test-durations-${{ inputs.flavor }}-${{ inputs.worker_group }}
214+
path: .test_durations.*
215+
if-no-files-found: ignore

.github/workflows/run_tests_coordinator.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,11 @@ jobs:
5757
strategy:
5858
fail-fast: false
5959
matrix:
60-
worker_group: ${{ fromJSON(contains(inputs.flavor, 'cpu-unit') && '[1, 2]' || '[1]') }}
60+
worker_group: ${{ fromJSON(contains(inputs.flavor, 'cpu-unit') && '[1, 2, 3, 4]' || '[1]') }}
6161

6262
uses: ./.github/workflows/run_tests_against_package.yml
6363
with:
64+
flavor: ${{ inputs.flavor }}
6465
# Infrastructure Mapping
6566
device_type: >-
6667
${{ fromJSON('{
@@ -148,5 +149,5 @@ jobs:
148149
is_scheduled_run: ${{ inputs.is_scheduled_run }}
149150
maxtext_installed: ${{ inputs.maxtext_installed }}
150151
worker_group: ${{ matrix.worker_group }}
151-
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }}
152+
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 4 || 1 }}
152153
maxtext_sha: ${{ inputs.maxtext_sha }}

tests/unit/train_compile_test.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,8 @@ def test_moe_llama4_17b_16e(self):
682682
)
683683
)
684684

685-
@pytest.mark.cpu_only
686-
def test_moe_gpt_oss_20b_sparse_matmul(self):
687-
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"
685+
def _run_moe_gpt_oss_20b(self, suffix, matmul_args):
686+
compiled_trainstep_file = f"/tmp/test_moe_gpt_oss_20b_{suffix}.pickle"
688687
train_compile_main(
689688
(
690689
"",
@@ -698,33 +697,18 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
698697
"dtype=bfloat16",
699698
"weight_dtype=bfloat16",
700699
"scan_layers=True",
701-
"sparse_matmul=True",
702-
"megablox=True",
703700
"attention=flash",
701+
*matmul_args,
704702
)
705703
)
706704

705+
@pytest.mark.cpu_only
706+
def test_moe_gpt_oss_20b_sparse_matmul(self):
707+
self._run_moe_gpt_oss_20b("sparse_matmul", ["sparse_matmul=True", "megablox=True"])
708+
707709
@pytest.mark.cpu_only
708710
def test_moe_gpt_oss_20b_dense_matmul(self):
709-
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_dense_matmul.pickle"
710-
train_compile_main(
711-
(
712-
"",
713-
get_test_config_path(),
714-
f"compiled_trainstep_file={compiled_trainstep_file}",
715-
"compile_topology=v5p-16",
716-
"compile_topology_num_slices=1",
717-
"model_name=gpt-oss-20b",
718-
"per_device_batch_size=1",
719-
"max_target_length=1024",
720-
"dtype=bfloat16",
721-
"weight_dtype=bfloat16",
722-
"scan_layers=True",
723-
"sparse_matmul=False",
724-
"capacity_factor=-1",
725-
"attention=flash",
726-
)
727-
)
711+
self._run_moe_gpt_oss_20b("dense_matmul", ["sparse_matmul=False", "capacity_factor=-1"])
728712

729713
@pytest.mark.cpu_only
730714
def test_gpt3_6b(self):

0 commit comments

Comments (0)