A More Efficient SPMD Pipeline Parallelism for Large-scale Training #3071
copybara-service[bot] merged 9 commits into main from
Conversation
```python
cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec)
nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1, physical_partition_spec)
```
🔴 Critical bug: `loop_iteration + 1` does not fetch the next repeat for all stages when `num_stages > 2`. Because of the pipeline bubble, the stages are staggered: `loop_iteration + 1` fetches the next repeat only for the first stage or two, so later stages fetch weights from the previous repeat, producing incorrect computation. To fetch the next repeat for all stages, use `loop_iteration + (self.num_stages - 1) * self.forwarding_delay`. Similarly, to ensure `cur_repeat_weights` contains the current repeat for all stages, use `loop_iteration - 1`.

```suggestion
    cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration - 1, physical_partition_spec)
    nxt_repeat_weights = self.from_all_variables_to_repeat_weights(
        weights, loop_iteration + (self.num_stages - 1) * self.forwarding_delay, physical_partition_spec
    )
```
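To make the staggering concrete, here is a minimal plain-Python sketch (the helper `stage_effective_iteration` is hypothetical, not part of this PR): each stage begins work `forwarding_delay` iterations after the previous one, so at any `loop_iteration` the last stage trails stage 0 by `(num_stages - 1) * forwarding_delay` iterations, which is exactly how far ahead the prefetch must look.

```python
def stage_effective_iteration(loop_iteration, stage, forwarding_delay=1):
    # In a staggered pipeline, stage s begins work forwarding_delay * s
    # iterations after stage 0 (the pipeline bubble), so its progress
    # lags the loop counter by that amount.
    return loop_iteration - stage * forwarding_delay

num_stages, forwarding_delay, loop_iteration = 4, 1, 10
lags = [loop_iteration - stage_effective_iteration(loop_iteration, s, forwarding_delay)
        for s in range(num_stages)]
print(lags)  # → [0, 1, 2, 3]: the last stage trails by (num_stages - 1) * forwarding_delay
```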
```python
  ):
    """Generates the buffer sliding window (bsw) from the gathered repeat weights."""
    bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec)
    repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
```
🟠 High severity: if a leaf in `physical_partition_spec` is `None` (which represents unconstrained or fully replicated sharding in JAX), `p[1:]` raises `TypeError: 'NoneType' object is not subscriptable`. Ensure `p` is actually a `PartitionSpec` or a tuple before slicing.

```suggestion
    repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]) if isinstance(p, tuple) else p, physical_partition_spec)
```
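As a stand-alone illustration of the guard (plain Python, no JAX dependency; `drop_leading_axis` is a hypothetical stand-in for the lambda above, with plain tuples standing in for `PartitionSpec` leaves):

```python
def drop_leading_axis(p):
    # None means "unconstrained / fully replicated"; pass it through
    # untouched instead of slicing, which would raise a TypeError.
    if isinstance(p, tuple):
        return p[1:]
    return p

specs = {"w": ("stage", "fsdp"), "bias": None}
trimmed = {k: drop_leading_axis(v) for k, v in specs.items()}
print(trimmed)  # → {'w': ('fsdp',), 'bias': None}
```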
```python
    # Drop the first dimension (usually the 'stage' or 'layer' axis handled by the scan)
    return P(*processed_pps[1:])
```
🟠 High severity: `remove_gathered_mesh_axes` can return `None` if `pps` is `None` (which JAX uses for unpartitioned/replicated arrays), and slicing `processed_pps[1:]` then crashes with a `TypeError`. Handle the case where `processed_pps` is not a `PartitionSpec` or tuple.

```suggestion
    # Drop the first dimension (usually the 'stage' or 'layer' axis handled by the scan)
    if isinstance(processed_pps, tuple):
        return P(*processed_pps[1:])
    return processed_pps
```
Commits:
- simple fix on debug sharding log
- add all gather insertion per repeat
- working all gather insertion
- clean version fsdp+pp bug free
- add bsw checkpoint
- split bsw all gather into two
- add custom vjp
suexu1025 left a comment:

A few comments for tokamax; otherwise it looks good.
Description
This pull request introduces a significantly refactored and more efficient implementation of SPMD Pipeline Parallelism (PP) in MaxText, designed to optimize large-scale training such as DeepSeek-V3. The core of this update is a new `CircularPipeline` that uses a "Buffer Sliding Window" (BSW) mechanism to manage weights across pipeline repeats, reducing memory overhead and improving computational efficiency.

Key Changes
1. Refactored Pipeline Architecture (`src/maxtext/layers/pipeline.py`)
- `PipelineBase`: a base class housing shared logic for the different pipeline implementations.
- `CircularPipeline`: a new implementation optimized for circular pipelining using BSW. It employs `jax.lax.scan` over pipeline repeats and iterations to minimize the overhead of weight handling.
- `create_pipeline`: a factory function that selects the appropriate pipeline module (either the original `Pipeline` or the new `CircularPipeline`) based on the provided configuration.
- The previous implementation is preserved as `src/maxtext/layers/pipeline_deprecated.py`.

2. Buffer Sliding Window (BSW) & Utilities (`src/maxtext/utils/pipeline_utils.py`)
- Added `all_gather_invariant`.
- Added `create_scanned_function`, which uses `jax.custom_vjp` to optimize the backward pass of the scanned pipeline iterations, managing gradients for weights and BSW states more effectively.

3. Model & Configuration Updates
- Updated `deepseek3-671b-2dfsdp.yml` to include the `stage` axis in the mesh and data sharding rules.
- Updated `src/maxtext/models/deepseek_batchsplit.py` to support `gmm` operations when pipeline parallelism is enabled.
- Added a `scan_pipeline_repeats` option to control whether to scan over repeats.

4. Testing & Validation
- Updated `tests/unit/train_compile_test.py` to verify compilation of the circular pipeline with DeepSeek-V3 across various parallelism strategies (FSDP, TP, EP).
- Tests exercise the `pipeline.create_pipeline` factory.

Performance Improvements
The refactoring, particularly the BSW and the use of scanned iterations with custom VJPs, is intended to provide a more efficient execution path for large-scale models by overlapping weight gathering with computation and reducing the memory pressure typically associated with complex pipeline schedules.
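The overlap described above can be sketched in a few lines. This is a hypothetical plain-Python model of a two-slot sliding window, not MaxText's actual API: the window holds the weights for the repeat currently being computed plus a prefetched copy of the next repeat's, so the expensive gather of repeat `r + 1` can overlap with compute on repeat `r`. Assumes at least two repeats.

```python
from collections import deque

def run_circular_pipeline(num_repeats, iters_per_repeat, fetch_repeat):
    # Two-slot buffer sliding window: slot 0 holds the current repeat's
    # weights, slot 1 holds the prefetched next repeat's weights.
    window = deque([fetch_repeat(0), fetch_repeat(1)], maxlen=2)
    log = []
    for r in range(num_repeats):
        cur = window[0]  # weights for the repeat currently being computed
        for _ in range(iters_per_repeat):
            log.append((r, cur))
        if r + 2 < num_repeats:
            window.append(fetch_repeat(r + 2))  # prefetch; evicts the oldest slot
        else:
            window.popleft()  # nothing left to prefetch; just slide
    return log

schedule = run_circular_pipeline(3, 2, lambda r: f"weights[{r}]")
print(schedule)
```

In the real implementation the "fetch" is an all-gather of sharded weights and the loop body is a `jax.lax.scan` step; the sketch only shows the window bookkeeping.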
Tests
- TPU-VM test between the main branch and this PR (webdiff)
- 2 V5p-8 (webdiff)
- New implementation correctness: losses match (webdiff)
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.