2 changes: 1 addition & 1 deletion .github/workflows/run_pathways_tests.yml
Original file line number Diff line number Diff line change
@@ -102,7 +102,7 @@ jobs:
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets
export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext
# TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported.
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not CompileThenLoad" --durations=0
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not AotJaxprIdenticalTest and not CompileThenLoad and not test_diloco_two_slices" --durations=0
🟡 Skipping `test_diloco_two_slices` helps stabilize the CI, but it would be beneficial to link a tracking issue or add a TODO explaining why this test is being skipped and if it's intended to be re-enabled later.

env:
PYTHONPATH: "${{ github.workspace }}/src"
services:
18 changes: 18 additions & 0 deletions .github/workflows/run_tests_against_package.yml
@@ -19,6 +19,11 @@ name: Run Tests Against MaxText Package
on:
workflow_call:
inputs:
flavor:
description: 'Test flavor name (e.g. cpu-unit, tpu-unit) - used for artifact naming'
required: false
type: string
default: ''
device_type:
required: true
type: string
@@ -164,6 +169,11 @@ jobs:
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
$PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist
SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto"
# On scheduled runs, record per-shard durations so future splits balance by time (LPT).
# Merge artifacts offline and commit as .test_durations at repo root.
if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then
🟢 Recording per-shard durations on scheduled runs is an excellent strategy for optimizing CI load balancing. Merging these artifacts into a central `.test_durations` file will allow `pytest-split` to balance future runs more effectively using longest-processing-time (LPT) scheduling.

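The LPT balancing mentioned in the comment above can be sketched as a greedy longest-first assignment. This is illustrative only; pytest-split's actual splitting algorithm may differ, and `split_by_duration` is a hypothetical helper, not part of its API:

```python
def split_by_duration(durations, num_shards):
    """Greedy LPT sketch: longest tests first, each to the lightest shard.

    durations: {test_id: seconds}, as recorded in a .test_durations file.
    Returns a list of per-shard test-id lists.
    """
    shards = [[] for _ in range(num_shards)]
    loads = [0.0] * num_shards
    # Sort by descending duration so the biggest tests are placed first.
    for test, secs in sorted(durations.items(), key=lambda kv: -kv[1]):
        i = loads.index(min(loads))  # currently least-loaded shard
        shards[i].append(test)
        loads[i] += secs
    return shards
```

With timings recorded, shards end up balanced by wall-clock time rather than by test count.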
SPLIT_ARGS="${SPLIT_ARGS} --store-durations --durations-path=.test_durations.${INPUTS_WORKER_GROUP}"
fi
else
SPLIT_ARGS=""
fi
@@ -195,3 +205,11 @@ jobs:
# If scheduled, upload to scheduled flag only. If PR, upload to regular flag only.
flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }}
verbose: true
- name: Upload test durations artifact
if: ${{ inputs.is_scheduled_run == 'true' && inputs.total_workers > 1 }}
uses: actions/upload-artifact@v4
continue-on-error: true
with:
name: test-durations-${{ inputs.flavor }}-${{ inputs.worker_group }}
path: .test_durations.*
if-no-files-found: ignore
1 change: 1 addition & 0 deletions .github/workflows/run_tests_coordinator.yml
@@ -61,6 +61,7 @@ jobs:

uses: ./.github/workflows/run_tests_against_package.yml
with:
flavor: ${{ inputs.flavor }}
# Infrastructure Mapping
device_type: >-
${{ fromJSON('{
79 changes: 48 additions & 31 deletions tests/unit/train_compile_test.py
@@ -17,6 +17,12 @@
This module contains unit tests for `train_compile.py`, ensuring that various
model configurations and parallelism strategies can be successfully compiled
for different hardware topologies.

These tests exercise the compilation pipeline only, not numerical correctness,
so most use `override_model_config=true` with a reduced `base_num_decoder_layers`
to keep CPU compile times bounded. Full-scale model correctness is covered
elsewhere. Tests that deliberately keep the full layer count do so to exercise
sharding/pipeline edge cases and are annotated inline.
"""

import unittest
@@ -373,6 +379,8 @@ def test_moe_dropping_bf16(self):
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=mixtral-8x7b",
"override_model_config=true",
"base_num_decoder_layers=8",
🟢 Using `override_model_config=true` and reducing `base_num_decoder_layers` to 8 is a great optimization for compilation tests. This significantly reduces the graph size and compilation time while still exercising the necessary code paths for MoE models.

"sparse_matmul=False",
"capacity_factor=1",
"per_device_batch_size=4",
@@ -420,6 +428,8 @@ def test_moe_megablox_bf16(self):
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=mixtral-8x7b",
"override_model_config=true",
"base_num_decoder_layers=8",
"sparse_matmul=True",
"megablox=True",
"per_device_batch_size=4",
@@ -442,6 +452,8 @@ def test_moe_megablox_ring_ep_random(self):
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=deepseek3-test",
"override_model_config=true",
"base_num_decoder_layers=8",
"sparse_matmul=True",
"megablox=True",
"per_device_batch_size=4",
@@ -466,6 +478,8 @@ def test_moe_ragged_dot_bf16(self):
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=mixtral-8x7b",
"override_model_config=true",
"base_num_decoder_layers=8",
"sparse_matmul=True",
"megablox=False",
"per_device_batch_size=4",
@@ -488,6 +502,8 @@ def test_moe_dense_bf16(self):
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=mixtral-8x7b",
"override_model_config=true",
"base_num_decoder_layers=8",
"sparse_matmul=False",
"capacity_factor=-1",
"per_device_batch_size=4",
@@ -534,6 +550,8 @@ def test_moe_pp_bf16(self):
"use_iota_embed=true",
"compile_topology_num_slices=2",
"model_name=mixtral-8x7b",
"override_model_config=true",
"base_num_decoder_layers=8",
"sparse_matmul=False",
"capacity_factor=1",
"per_device_batch_size=4",
@@ -558,6 +576,8 @@ def test_moe_deepseek_scanned_bf16(self):
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=deepseek3-test",
"override_model_config=true",
"base_num_decoder_layers=8",
"sparse_matmul=True",
"megablox=False",
"per_device_batch_size=2",
@@ -606,6 +626,8 @@ def test_moe_deepseek_with_device_limit(self):
"use_iota_embed=true",
"compile_topology_num_slices=1",
"model_name=deepseek3-test",
"override_model_config=true",
"base_num_decoder_layers=8",
"sparse_matmul=True",
"megablox=False",
"per_device_batch_size=1",
@@ -620,14 +642,16 @@ def test_moe_deepseek_with_device_limit(self):

@pytest.mark.cpu_only
def test_moe_deepseek_pipeline_subset(self):
# Keeps the full layer count so pipeline_parallel_layers=56 exercises
# the real stage boundaries.
compiled_trainstep_file = "/tmp/test_moe_deepseek_pipeline_subset.pickle"
train_compile_main(
(
"",
get_test_config_path(),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-64",
"compile_topology_num_slices=8",
"compile_topology=v5p-8",
"compile_topology_num_slices=2",
🟢 Reducing the `compile_topology` and `compile_topology_num_slices` here is appropriate for a unit/compilation test. It speeds up the CI by requiring fewer resources and less time to verify that the pipeline parallelism logic compiles correctly.

"use_iota_embed=true",
"model_name=deepseek3-test",
"megablox=True",
Expand All @@ -636,8 +660,9 @@ def test_moe_deepseek_pipeline_subset(self):
"per_device_batch_size=1",
"max_target_length=1024",
"pipeline_parallel_layers=56",
"ici_expert_parallelism=16",
"dcn_pipeline_parallelism=8",
"ici_expert_parallelism=4",
"ici_fsdp_parallelism=1",
"dcn_pipeline_parallelism=2",
)
)

@@ -669,22 +694,23 @@ def test_moe_llama4_17b_16e(self):
"",
get_test_config_path(),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-128",
"compile_topology=v5p-16",
"compile_topology_num_slices=1",
"model_name=llama4-17b-16e",
"override_model_config=true",
"base_num_decoder_layers=4",
"per_device_batch_size=1",
"max_target_length=1024",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=True",
"ici_fsdp_parallelism=16",
"ici_tensor_parallelism=4",
"ici_fsdp_parallelism=4",
"ici_tensor_parallelism=2",
)
)

@pytest.mark.cpu_only
def test_moe_gpt_oss_20b_sparse_matmul(self):
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"
def _run_moe_gpt_oss_20b(self, suffix, matmul_args):
compiled_trainstep_file = f"/tmp/test_moe_gpt_oss_20b_{suffix}.pickle"
train_compile_main(
(
"",
@@ -693,38 +719,25 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
"compile_topology=v5p-16",
"compile_topology_num_slices=1",
"model_name=gpt-oss-20b",
"override_model_config=true",
"base_num_decoder_layers=8",
"per_device_batch_size=1",
"max_target_length=1024",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=True",
"sparse_matmul=True",
"megablox=True",
"attention=flash",
*matmul_args,
)
)

@pytest.mark.cpu_only
def test_moe_gpt_oss_20b_sparse_matmul(self):
self._run_moe_gpt_oss_20b("sparse_matmul", ["sparse_matmul=True", "megablox=True"])

@pytest.mark.cpu_only
def test_moe_gpt_oss_20b_dense_matmul(self):
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_dense_matmul.pickle"
train_compile_main(
(
"",
get_test_config_path(),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-16",
"compile_topology_num_slices=1",
"model_name=gpt-oss-20b",
"per_device_batch_size=1",
"max_target_length=1024",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=True",
"sparse_matmul=False",
"capacity_factor=-1",
"attention=flash",
)
)
self._run_moe_gpt_oss_20b("dense_matmul", ["sparse_matmul=False", "capacity_factor=-1"])

@pytest.mark.cpu_only
def test_gpt3_6b(self):
@@ -769,6 +782,8 @@ def test_qwen3_next(self):
"compile_topology=v5p-256",
"compile_topology_num_slices=1",
"model_name=qwen3-next-80b-a3b",
"override_model_config=true",
"base_num_decoder_layers=8",
"per_device_batch_size=1",
"max_target_length=1024",
)
@@ -867,6 +882,8 @@ def test_olmo3_7b(self):
"compile_topology=v5p-8",
"compile_topology_num_slices=1",
"model_name=olmo3-7b",
"override_model_config=true",
"base_num_decoder_layers=8",
"per_device_batch_size=1",
"scan_layers=True",
"max_target_length=1024",