Speed up cpu-unit CI #3700
Draft · gagika wants to merge 1 commit into main from agagik-ci
The first changed file is the reusable test workflow (`Run Tests Against MaxText Package`):

```diff
@@ -19,6 +19,11 @@ name: Run Tests Against MaxText Package
 on:
   workflow_call:
     inputs:
+      flavor:
+        description: 'Test flavor name (e.g. cpu-unit, tpu-unit) - used for artifact naming'
+        required: false
+        type: string
+        default: ''
       device_type:
         required: true
         type: string
```
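For context, a caller of this reusable workflow would pass the new `flavor` input roughly as below; the workflow path and job name are hypothetical, not taken from this PR.

```yaml
# Hypothetical caller; the reusable-workflow path is illustrative only.
jobs:
  cpu-unit:
    uses: ./.github/workflows/run_tests_internal.yml
    with:
      flavor: cpu-unit  # used for artifact naming, e.g. test-durations-cpu-unit-1
      device_type: cpu
```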
```diff
@@ -164,6 +169,11 @@ jobs:
           if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
             $PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist
             SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto"
+            # On scheduled runs, record per-shard durations so future splits balance by time (LPT).
+            # Merge artifacts offline and commit as .test_durations at repo root.
+            if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then
+              SPLIT_ARGS="${SPLIT_ARGS} --store-durations --durations-path=.test_durations.${INPUTS_WORKER_GROUP}"
+            fi
           else
             SPLIT_ARGS=""
           fi
```

> 🟢 Recording per-shard durations on scheduled runs is an excellent strategy for optimizing CI load balancing. Merging these artifacts into a central `.test_durations` file will allow `pytest-split` to balance future runs more effectively using the Longest Processing Time (LPT) heuristic.
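Replaying the shard-splitting logic above outside CI makes the resulting pytest arguments concrete; the worker counts below are example values, not values from the workflow.

```shell
# Stand-alone replay of the workflow's SPLIT_ARGS construction.
# Example values (hypothetical): 4 shards total, this job is shard 2, scheduled run.
INPUTS_TOTAL_WORKERS=4
INPUTS_WORKER_GROUP=2
INPUTS_IS_SCHEDULED_RUN=true

if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
  SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto"
  # Scheduled runs additionally record per-test durations for this shard.
  if [ "${INPUTS_IS_SCHEDULED_RUN}" = "true" ]; then
    SPLIT_ARGS="${SPLIT_ARGS} --store-durations --durations-path=.test_durations.${INPUTS_WORKER_GROUP}"
  fi
else
  SPLIT_ARGS=""
fi
echo "${SPLIT_ARGS}"
```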
```diff
@@ -195,3 +205,11 @@ jobs:
         # If scheduled, upload to scheduled flag only. If PR, upload to regular flag only.
         flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }}
         verbose: true
+      - name: Upload test durations artifact
+        if: ${{ inputs.is_scheduled_run == 'true' && inputs.total_workers > 1 }}
+        uses: actions/upload-artifact@v4
+        continue-on-error: true
+        with:
+          name: test-durations-${{ inputs.flavor }}-${{ inputs.worker_group }}
+          path: .test_durations.*
+          if-no-files-found: ignore
```
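The workflow comments say the uploaded per-shard files are merged offline and committed as `.test_durations` at the repo root. A minimal sketch of that merge step, assuming (as pytest-split documents) that each durations file is a flat JSON object mapping test id to seconds; the function name is illustrative, not from the PR:

```python
import glob
import json


def merge_durations(pattern=".test_durations.*", out_path=".test_durations"):
  """Merge per-shard pytest-split duration files into one JSON file."""
  merged = {}
  for path in sorted(glob.glob(pattern)):
    with open(path) as f:
      # Each shard recorded disjoint tests; later shards win on (unlikely) collisions.
      merged.update(json.load(f))
  with open(out_path, "w") as f:
    json.dump(merged, f, indent=2, sort_keys=True)
  return merged


if __name__ == "__main__":
  merge_durations()
```

The merged file is what `--splits/--group` reads on subsequent runs to assign tests to shards by measured duration rather than by count.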
The second changed file contains the `train_compile.py` unit tests:

```diff
@@ -17,6 +17,12 @@
 This module contains unit tests for `train_compile.py`, ensuring that various
 model configurations and parallelism strategies can be successfully compiled
 for different hardware topologies.
+
+These tests exercise the compilation pipeline only, not numerical correctness,
+so most use `override_model_config=true` with a reduced `base_num_decoder_layers`
+to keep CPU compile times bounded. Full-scale model correctness is covered
+elsewhere. Tests that deliberately keep the full layer count do so to exercise
+sharding/pipeline edge cases and are annotated inline.
 """

 import unittest
```
```diff
@@ -373,6 +379,8 @@ def test_moe_dropping_bf16(self):
             "use_iota_embed=true",
             "compile_topology_num_slices=1",
             "model_name=mixtral-8x7b",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "sparse_matmul=False",
             "capacity_factor=1",
             "per_device_batch_size=4",
```

> 🟢 Using `override_model_config=true` and reducing `base_num_decoder_layers` to 8 is a great optimization for compilation tests. This significantly reduces the graph size and compilation time while still exercising the necessary code paths for MoE models.
```diff
@@ -420,6 +428,8 @@ def test_moe_megablox_bf16(self):
             "use_iota_embed=true",
             "compile_topology_num_slices=1",
             "model_name=mixtral-8x7b",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "sparse_matmul=True",
             "megablox=True",
             "per_device_batch_size=4",
```
```diff
@@ -442,6 +452,8 @@ def test_moe_megablox_ring_ep_random(self):
             "use_iota_embed=true",
             "compile_topology_num_slices=1",
             "model_name=deepseek3-test",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "sparse_matmul=True",
             "megablox=True",
             "per_device_batch_size=4",
```
```diff
@@ -466,6 +478,8 @@ def test_moe_ragged_dot_bf16(self):
             "use_iota_embed=true",
             "compile_topology_num_slices=1",
             "model_name=mixtral-8x7b",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "sparse_matmul=True",
             "megablox=False",
             "per_device_batch_size=4",
```
```diff
@@ -488,6 +502,8 @@ def test_moe_dense_bf16(self):
             "use_iota_embed=true",
             "compile_topology_num_slices=1",
             "model_name=mixtral-8x7b",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "sparse_matmul=False",
             "capacity_factor=-1",
             "per_device_batch_size=4",
```
```diff
@@ -534,6 +550,8 @@ def test_moe_pp_bf16(self):
             "use_iota_embed=true",
             "compile_topology_num_slices=2",
             "model_name=mixtral-8x7b",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "sparse_matmul=False",
             "capacity_factor=1",
             "per_device_batch_size=4",
```
```diff
@@ -558,6 +576,8 @@ def test_moe_deepseek_scanned_bf16(self):
             "use_iota_embed=true",
             "compile_topology_num_slices=1",
             "model_name=deepseek3-test",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "sparse_matmul=True",
             "megablox=False",
             "per_device_batch_size=2",
```
```diff
@@ -606,6 +626,8 @@ def test_moe_deepseek_with_device_limit(self):
             "use_iota_embed=true",
             "compile_topology_num_slices=1",
             "model_name=deepseek3-test",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "sparse_matmul=True",
             "megablox=False",
             "per_device_batch_size=1",
```
```diff
@@ -620,14 +642,16 @@ def test_moe_deepseek_with_device_limit(self):

   @pytest.mark.cpu_only
   def test_moe_deepseek_pipeline_subset(self):
+    # Keeps the full layer count so pipeline_parallel_layers=56 exercises
+    # the real stage boundaries.
     compiled_trainstep_file = "/tmp/test_moe_deepseek_pipeline_subset.pickle"
     train_compile_main(
         (
             "",
             get_test_config_path(),
             f"compiled_trainstep_file={compiled_trainstep_file}",
-            "compile_topology=v5p-64",
-            "compile_topology_num_slices=8",
+            "compile_topology=v5p-8",
+            "compile_topology_num_slices=2",
             "use_iota_embed=true",
             "model_name=deepseek3-test",
             "megablox=True",
```

> 🟢 Reducing the `compile_topology` and `compile_topology_num_slices` here is appropriate for a unit/compilation test. It speeds up the CI by requiring fewer resources and less time to verify that the pipeline parallelism logic compiles correctly.
```diff
@@ -636,8 +660,9 @@ def test_moe_deepseek_pipeline_subset(self):
             "per_device_batch_size=1",
             "max_target_length=1024",
             "pipeline_parallel_layers=56",
-            "ici_expert_parallelism=16",
-            "dcn_pipeline_parallelism=8",
+            "ici_expert_parallelism=4",
+            "ici_fsdp_parallelism=1",
+            "dcn_pipeline_parallelism=2",
         )
     )
```
```diff
@@ -669,22 +694,23 @@ def test_moe_llama4_17b_16e(self):
             "",
             get_test_config_path(),
             f"compiled_trainstep_file={compiled_trainstep_file}",
-            "compile_topology=v5p-128",
+            "compile_topology=v5p-16",
             "compile_topology_num_slices=1",
             "model_name=llama4-17b-16e",
+            "override_model_config=true",
+            "base_num_decoder_layers=4",
             "per_device_batch_size=1",
             "max_target_length=1024",
             "dtype=bfloat16",
             "weight_dtype=bfloat16",
             "scan_layers=True",
-            "ici_fsdp_parallelism=16",
-            "ici_tensor_parallelism=4",
+            "ici_fsdp_parallelism=4",
+            "ici_tensor_parallelism=2",
         )
     )

   @pytest.mark.cpu_only
-  def test_moe_gpt_oss_20b_sparse_matmul(self):
-    compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"
+  def _run_moe_gpt_oss_20b(self, suffix, matmul_args):
+    compiled_trainstep_file = f"/tmp/test_moe_gpt_oss_20b_{suffix}.pickle"
     train_compile_main(
         (
             "",
```
```diff
@@ -693,38 +719,25 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
             "compile_topology=v5p-16",
             "compile_topology_num_slices=1",
             "model_name=gpt-oss-20b",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "per_device_batch_size=1",
             "max_target_length=1024",
             "dtype=bfloat16",
             "weight_dtype=bfloat16",
             "scan_layers=True",
-            "sparse_matmul=True",
-            "megablox=True",
             "attention=flash",
+            *matmul_args,
         )
     )

   @pytest.mark.cpu_only
+  def test_moe_gpt_oss_20b_sparse_matmul(self):
+    self._run_moe_gpt_oss_20b("sparse_matmul", ["sparse_matmul=True", "megablox=True"])
+
+  @pytest.mark.cpu_only
   def test_moe_gpt_oss_20b_dense_matmul(self):
-    compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_dense_matmul.pickle"
-    train_compile_main(
-        (
-            "",
-            get_test_config_path(),
-            f"compiled_trainstep_file={compiled_trainstep_file}",
-            "compile_topology=v5p-16",
-            "compile_topology_num_slices=1",
-            "model_name=gpt-oss-20b",
-            "per_device_batch_size=1",
-            "max_target_length=1024",
-            "dtype=bfloat16",
-            "weight_dtype=bfloat16",
-            "scan_layers=True",
-            "sparse_matmul=False",
-            "capacity_factor=-1",
-            "attention=flash",
-        )
-    )
+    self._run_moe_gpt_oss_20b("dense_matmul", ["sparse_matmul=False", "capacity_factor=-1"])

   @pytest.mark.cpu_only
   def test_gpt3_6b(self):
```
```diff
@@ -769,6 +782,8 @@ def test_qwen3_next(self):
             "compile_topology=v5p-256",
             "compile_topology_num_slices=1",
             "model_name=qwen3-next-80b-a3b",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "per_device_batch_size=1",
             "max_target_length=1024",
         )
```
```diff
@@ -867,6 +882,8 @@ def test_olmo3_7b(self):
             "compile_topology=v5p-8",
             "compile_topology_num_slices=1",
             "model_name=olmo3-7b",
+            "override_model_config=true",
+            "base_num_decoder_layers=8",
             "per_device_batch_size=1",
             "scan_layers=True",
             "max_target_length=1024",
```