
A More Efficient SPMD Pipeline Parallelism for Large-scale Training#3071

Merged
copybara-service[bot] merged 9 commits into main from chengnuojin-pp-separate-weights on Mar 12, 2026
Conversation

Collaborator

@NuojCheng NuojCheng commented Feb 3, 2026

Description

This pull request introduces a significantly refactored and more efficient implementation of SPMD Pipeline Parallelism (PP) in MaxText, designed to optimize large-scale training of models such as DeepSeek-V3. The core of this update is a CircularPipeline that uses a "Buffer Sliding Window" (BSW) mechanism to manage weights across pipeline repeats, reducing memory overhead and improving computational efficiency.

Key Changes

1. Refactored Pipeline Architecture (src/maxtext/layers/pipeline.py)

  • PipelineBase: Introduced a base class to house shared logic for different pipeline implementations.
  • CircularPipeline: A new implementation optimized for circular pipelining using BSW. It employs jax.lax.scan over pipeline repeats and iterations to minimize the overhead of weight handling.
  • create_pipeline: A factory function that selects the appropriate pipeline module (either the original Pipeline or the new CircularPipeline) based on the provided configuration.
  • Deprecated Implementation: The original pipeline implementation has been moved to src/maxtext/layers/pipeline_deprecated.py.
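A minimal sketch of what such a factory might look like (the class bodies and the config attribute access are assumptions for illustration; only the names `create_pipeline`, `Pipeline`, `CircularPipeline`, and `scan_pipeline_repeats` come from this PR):

```python
# Hypothetical sketch of a pipeline factory; the real MaxText signature differs.

class Pipeline:
  """Stand-in for the original pipeline implementation."""

class CircularPipeline:
  """Stand-in for the new circular pipeline using the Buffer Sliding Window."""

def create_pipeline(config):
  # Select the circular implementation when pipeline repeats are scanned;
  # `scan_pipeline_repeats` is the new config option introduced by this PR.
  if getattr(config, "scan_pipeline_repeats", False):
    return CircularPipeline()
  return Pipeline()
```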

2. Buffer Sliding Window (BSW) & Utilities (src/maxtext/utils/pipeline_utils.py)

  • Introduced a new utility module for pipeline-specific operations.
  • Implemented logic for gathering FSDP-partitioned variables across pipeline stages and repeats using all_gather_invariant.
  • Custom VJP Logic: Added create_scanned_function which uses jax.custom_vjp to optimize the backward pass of the scanned pipeline iterations, managing gradients for weights and BSW states more effectively.
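The custom-VJP-over-scan pattern can be illustrated with a toy sketch (this is not the MaxText `create_scanned_function`; the function body and its closed-form gradient rule are simplified assumptions for illustration):

```python
import jax
import jax.numpy as jnp

# Toy sketch of wrapping a scanned computation in jax.custom_vjp so the
# backward pass is specified by hand instead of relying on scan's default
# gradient accumulation. Illustrative only; not the MaxText implementation.

@jax.custom_vjp
def scanned_apply(weights, xs):
  def step(carry, x):
    return carry + weights * x, None
  out, _ = jax.lax.scan(step, jnp.zeros_like(xs[0]), xs)
  return out

def scanned_apply_fwd(weights, xs):
  # Save only what the backward pass needs (the residuals).
  return scanned_apply(weights, xs), (weights, xs)

def scanned_apply_bwd(res, g):
  weights, xs = res
  # out = weights * sum(xs), so both gradients have closed forms here.
  d_weights = g * xs.sum(axis=0)
  d_xs = jnp.broadcast_to(g * weights, xs.shape)
  return d_weights, d_xs

scanned_apply.defvjp(scanned_apply_fwd, scanned_apply_bwd)
```

In the real pipeline the hand-written backward pass is what lets gradients for weights and BSW states be managed explicitly instead of being accumulated across every scanned iteration.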

3. Model & Configuration Updates

  • DeepSeek-V3 Integration: Updated deepseek3-671b-2dfsdp.yml to include the stage axis in the mesh and data sharding rules.
  • Batch Splitting: Modified src/maxtext/models/deepseek_batchsplit.py to support gmm operations when pipeline parallelism is enabled.
  • New Config Options: Added scan_pipeline_repeats to control whether to scan over repeats.
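An illustrative config fragment (only the `scan_pipeline_repeats` key name comes from this PR; the value shown is an assumption, and `ici_pipeline_parallelism` is taken from the smoke-test commands below):

```yaml
ici_pipeline_parallelism: 2   # number of pipeline stages, as in the smoke tests below
scan_pipeline_repeats: true   # scan over pipeline repeats (circular pipeline path)
```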

4. Testing & Validation

  • AOT Compile Tests: Added several new tests in tests/unit/train_compile_test.py to verify the compilation of the circular pipeline with DeepSeek-V3 across various parallelism strategies (FSDP, TP, EP).
  • Updated existing unit tests to use the new pipeline.create_pipeline factory.

Performance Improvements

The refactoring, particularly the BSW and the use of scanned iterations with custom VJPs, is intended to provide a more efficient execution path for large-scale models by overlapping weight gathering with computation and reducing the memory pressure typically associated with complex pipeline schedules.
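The overlap described above can be sketched as a double-buffering loop in plain Python (illustrative only; the real implementation uses `jax.lax.scan` with collectives that XLA overlaps with compute, and all names here are hypothetical):

```python
def run_repeats(repeat_weights, microbatch, num_repeats, gather, compute):
  """Sketch of the Buffer Sliding Window: prefetch the next repeat's weights
  while the current repeat computes, hiding gather latency behind compute."""
  buffer = gather(repeat_weights[0])  # prefill the window
  for r in range(num_repeats):
    # Kick off the gather for the next repeat before computing on the current
    # buffer; in JAX this would be an all-gather overlapped with the compute.
    nxt = gather(repeat_weights[r + 1]) if r + 1 < num_repeats else None
    microbatch = compute(buffer, microbatch)
    buffer = nxt  # slide the window forward
  return microbatch
```

With `gather` as the identity and `compute` as addition, the loop accumulates each repeat's weights into the microbatch exactly once.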

Tests

TPU-VM tests comparing the main branch with this PR:

  1. V5p-8

     smoke_train ici_pipeline_parallelism=2

     webdiff

  2. V5p-8

     smoke_train ici_pipeline_parallelism=2 pipeline_fsdp_ag_once=true

     webdiff

New implementation correctness (losses match):

  smoke_train ici_pipeline_parallelism=2

  webdiff

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Feb 3, 2026

Codecov Report

❌ Patch coverage is 88.88889% with 13 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/layers/pipeline.py 88.88% 7 Missing and 6 partials ⚠️

📢 Thoughts on this report? Let us know!

@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 2 times, most recently from 6c22238 to 28f98ff Compare February 9, 2026 19:29
@codecov

codecov Bot commented Feb 9, 2026

Codecov Report

❌ Patch coverage is 81.00962% with 79 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/pipeline.py 82.38% 41 Missing and 15 partials ⚠️
src/maxtext/utils/pipeline_utils.py 76.04% 19 Missing and 4 partials ⚠️

📢 Thoughts on this report? Let us know!

@NuojCheng NuojCheng added the draft (Draft PR) label and removed the pull ready label Feb 9, 2026
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch from 28f98ff to 64b37ff Compare February 9, 2026 22:23
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 8 times, most recently from a7d38d0 to 9a36099 Compare February 18, 2026 17:16
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 2 times, most recently from d05c015 to e521a58 Compare February 24, 2026 23:14
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 2 times, most recently from 51e6713 to 286e066 Compare February 26, 2026 00:16
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 3 times, most recently from 0efb56e to d7394fe Compare March 3, 2026 01:28
Comment thread src/maxtext/layers/pipeline.py Outdated
Comment thread tests/unit/train_compile_test.py
Comment thread tests/unit/train_compile_test.py Outdated
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 3 times, most recently from d70484b to 7d31fb9 Compare March 6, 2026 21:49
@NuojCheng NuojCheng added the gemini-review label and removed the draft (Draft PR) label Mar 6, 2026
@github-actions

github-actions Bot commented Mar 6, 2026

🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.



@github-actions github-actions Bot left a comment


## 📋 Review Summary

This pull request introduces an advanced pipeline parallelism refactoring by splitting out `PipelineBase`, adding a new `CircularPipeline` module with weight prefetching, and updating `Pipeline`. The changes improve large-scale training efficiency by hiding weight-gathering collectives behind microbatched computation.

## 🔍 General Feedback

- **Positive Highlights**: The use of custom VJPs with `flax.linen.scan` is an elegant, advanced technique for avoiding memory accumulation during microbatch looping while applying gradients manually. The abstraction into `PipelineBase` cleanly separates the shared orchestration logic from the execution implementations.
- **Critical Logic Errors**: The weight-prefetching algorithm computed `nxt_repeat_weights` via `loop_iteration + 1`, which does not correctly span the staggered pipeline bubble for `num_stages > 2`, so later stages fall back to the previous repeat's weights and execute incorrectly.
- **Python Edge Cases**: Assuming every PyTree leaf is a `PartitionSpec` can crash on `None` leaves, which JAX commonly uses to represent unpartitioned data. Slicing such leaves directly without type checking has been flagged.
- **Cleanup**: Consider removing the newly added `pipeline_deprecated.py` file entirely if it remains unused and is just a copy of the old implementation, as it incurs technical debt.

Comment thread src/maxtext/layers/pipeline.py Outdated
Comment on lines +1205 to +1206:

```python
cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec)
nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1, physical_partition_spec)
```


🔴 Critical bug: `loop_iteration + 1` does not fetch the next repeat for all stages when `num_stages > 2`. Due to the pipeline bubble, stages are staggered, so `loop_iteration + 1` fetches the next repeat only for the first stage or two, causing the later stages to fetch weights from the previous repeat and compute incorrectly. To fetch the next repeat for all stages, use `loop_iteration + (self.num_stages - 1) * self.forwarding_delay`. Similarly, to ensure `cur_repeat_weights` contains the current repeat for all stages, use `loop_iteration - 1`.

Suggestion:

```python
cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration - 1, physical_partition_spec)
nxt_repeat_weights = self.from_all_variables_to_repeat_weights(
    weights, loop_iteration + (self.num_stages - 1) * self.forwarding_delay, physical_partition_spec
)
```

Collaborator Author

wrong comment

Comment thread src/maxtext/layers/pipeline.py Outdated
```python
):
  """Generates the buffer sliding window (bsw) from the gathered repeat weights."""
  bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec)
  repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
```

🟠 High severity: if a leaf in `physical_partition_spec` is `None` (which represents unconstrained or fully replicated sharding in JAX), `p[1:]` will throw `TypeError: 'NoneType' object is not subscriptable`. Ensure `p` is actually a `PartitionSpec` or a tuple before slicing.

Suggestion:

```python
repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]) if isinstance(p, tuple) else p, physical_partition_spec)
```

Comment on lines +81 to +82:

```python
# Drop the first dimension (usually the 'stage' or 'layer' axis handled by the scan)
return P(*processed_pps[1:])
```

🟠 High severity: `remove_gathered_mesh_axes` can receive `None` if `pps` is `None` (which JAX uses for unpartitioned/replicated arrays). Slicing `processed_pps[1:]` when it is `None` will crash with a `TypeError`. Handle the case where `processed_pps` is not a `PartitionSpec` or tuple.

Suggestion:

```python
# Drop the first dimension (usually the 'stage' or 'layer' axis handled by the scan)
if isinstance(processed_pps, tuple):
    return P(*processed_pps[1:])
return processed_pps
```

Comment thread tests/unit/pipeline_parallelism_test.py Outdated
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch from 7d31fb9 to a205197 Compare March 7, 2026 00:37
Comment thread src/maxtext/configs/base.yml Outdated
Comment thread src/maxtext/layers/moe.py Outdated
Comment thread src/maxtext/models/deepseek_batchsplit.py Outdated
Collaborator

@suexu1025 suexu1025 left a comment
A few comments for tokamax, otherwise it looks good.


7 participants