Skip to content

Commit 0963f05

Browse files
Merge pull request #3071 from AI-Hypercomputer:chengnuojin-pp-separate-weights
PiperOrigin-RevId: 882320546
2 parents 0fe1adf + 721bb5a commit 0963f05

9 files changed

Lines changed: 2234 additions & 409 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,6 @@ pipeline_parallel_layers: -1 # Pipeline only this number of layers - for the rem
275275
# PP degree divides the number of layers.
276276
# By default (when set to -1) we pipeline all of the decoder layers.
277277

278-
279278
# num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages.
280279
# Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices
281280
num_pipeline_microbatches: -1
@@ -288,8 +287,9 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights
288287
# to only one stage's worth, however we only execute one all-gather and reduce across per repeat, as opposed
289288
# to every microbatch. This is similar to zero-1 sharding, since we also don't need to all gather the FSDP weights in the backward pass.
290289
# An alternative to setting this to true may be to replace any FSDP with DP and use optimizer offloading if necessary.
291-
# A more optimal behavior is to all-gather at the start of each repeat, which would ideally get the best of both worlds -
292-
# a small amount of memory and time, however this has proven hard to implement in SPMD, see b/364386697 for more.
290+
pipeline_fsdp_ag_per_repeat: False
291+
# Pipeline weight prefetching per repeat is an advanced SPMD pipeline parallelism improvement technique.
292+
# When enabled, it prefetches necessary weight gathering ahead of microbatched computation, thereby reducing collectives.
293293

294294
# There are two loops for PP:
295295
# 1) Outer loop over microbatches (pipeline iterations)
@@ -299,6 +299,7 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights
299299
# It may be useful to do the reverse when the layers_per_stage is very large.
300300
# The below settings only have effect when using pipeline parallelism.
301301
scan_pipeline_iterations: True
302+
scan_pipeline_repeats: True
302303
scan_layers_per_stage: False
303304
set_remat_policy_on_pipeline_iterations: True
304305
set_remat_policy_on_layers_per_stage: False

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ rope_truncate: True
5656
rope_attention_scaling: False
5757

5858
override_logical_axis_rules: True
59-
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60-
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
59+
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60+
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
6161
logical_axis_rules: [
6262
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63-
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6464
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6565
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
6666
['activation_norm_length', ['context']],
6767
['activation_heads', []],
68+
['activation_stage', 'stage'],
6869
['embed', ['fsdp']],
6970
['embed_no_exp', ['fsdp']],
7071
['q_lora', ['fsdp']],
7172
['kv_lora', ['fsdp']],
73+
['layers', 'stage'],
7274
['q_lora_up_proj', ['fsdp_transpose', 'expert']],
7375
['kv_lora_up_proj', ['fsdp_transpose', 'expert']],
7476
['q_heads', ['fsdp_transpose', 'expert']],

src/maxtext/configs/types.py

Lines changed: 54 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,9 @@ class IciParallelism(BaseModel):
842842
class PipelineParallelism(BaseModel):
843843
"""Configuration for pipeline parallelism."""
844844

845+
pipeline_fsdp_ag_per_repeat: bool = Field(
846+
False, description="Enable weight prefetching for circular pipeline parallelism."
847+
)
845848
num_layers_per_pipeline_stage: int = Field(1, description="Number of layers to place on each pipeline stage.")
846849
num_pipeline_repeats: int = Field(
847850
-1,
@@ -857,6 +860,7 @@ class PipelineParallelism(BaseModel):
857860
)
858861
pipeline_fsdp_ag_once: bool = Field(False, description="If True, all-gather FSDP weights once per pipeline repeat.")
859862
scan_pipeline_iterations: bool = Field(True, description="Use jax.lax.scan over pipeline iterations.")
863+
scan_pipeline_repeats: bool = Field(True, description="Use jax.lax.scan over pipeline repeats.")
860864
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
861865
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
862866
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")
@@ -2250,6 +2254,17 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
22502254
)
22512255
self.num_pipeline_repeats = num_pipeline_repeats
22522256

2257+
if self.pipeline_fsdp_ag_per_repeat:
2258+
assert self.num_pipeline_repeats > 1, "Pipeline weight prefetching only supports circular pipeline."
2259+
assert (
2260+
self.num_layers_per_pipeline_stage == 1
2261+
), "Pipeline weight prefetching currently only supports one layer per pipeline stage."
2262+
assert (
2263+
not self.pipeline_delay_activation_forwarding
2264+
), "Pipeline weight prefetching does not support pipeline delay."
2265+
assert not self.quantization, "Quantization is currently not supported for pipeline prefetching."
2266+
assert not self.scan_layers_per_stage, "Pipeline weight prefetching currently does not support scan."
2267+
22532268
assert (num_stages * self.num_pipeline_repeats * self.num_layers_per_pipeline_stage) == (
22542269
self.pipeline_parallel_layers
22552270
), (
@@ -2539,78 +2554,45 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25392554
raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.")
25402555

25412556
# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
2542-
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
2543-
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2544-
self.ici_parallelism = [
2545-
self.ici_diloco_parallelism,
2546-
self.ici_pipeline_parallelism,
2547-
self.ici_data_parallelism,
2548-
self.ici_fsdp_parallelism,
2549-
self.ici_fsdp_transpose_parallelism,
2550-
self.ici_sequence_parallelism,
2551-
self.ici_context_parallelism,
2552-
self.ici_context_autoregressive_parallelism,
2553-
self.ici_tensor_parallelism,
2554-
self.ici_tensor_transpose_parallelism,
2555-
self.ici_tensor_sequence_parallelism,
2556-
self.ici_expert_parallelism,
2557-
self.ici_autoregressive_parallelism,
2558-
]
2559-
self.dcn_parallelism = [
2560-
self.dcn_diloco_parallelism,
2561-
self.dcn_pipeline_parallelism,
2562-
self.dcn_data_parallelism,
2563-
self.dcn_fsdp_parallelism,
2564-
self.dcn_fsdp_transpose_parallelism,
2565-
self.dcn_sequence_parallelism,
2566-
self.dcn_context_parallelism,
2567-
self.dcn_context_autoregressive_parallelism,
2568-
self.dcn_tensor_parallelism,
2569-
self.dcn_tensor_transpose_parallelism,
2570-
self.dcn_tensor_sequence_parallelism,
2571-
self.dcn_expert_parallelism,
2572-
self.dcn_autoregressive_parallelism,
2573-
]
2574-
else:
2575-
ici_map = {
2576-
"diloco": self.ici_diloco_parallelism,
2577-
"data": self.ici_data_parallelism,
2578-
"stage": self.ici_pipeline_parallelism,
2579-
"fsdp": self.ici_fsdp_parallelism,
2580-
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2581-
"sequence": self.ici_sequence_parallelism,
2582-
"context": self.ici_context_parallelism,
2583-
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2584-
"tensor": self.ici_tensor_parallelism,
2585-
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2586-
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2587-
"model": self.ici_tensor_parallelism,
2588-
"expert": self.ici_expert_parallelism,
2589-
"autoregressive": self.ici_autoregressive_parallelism,
2590-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2591-
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
2592-
}
2593-
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2594-
2595-
dcn_map = {
2596-
"diloco": self.dcn_diloco_parallelism,
2597-
"data": self.dcn_data_parallelism,
2598-
"stage": self.dcn_pipeline_parallelism,
2599-
"fsdp": self.dcn_fsdp_parallelism,
2600-
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2601-
"sequence": self.dcn_sequence_parallelism,
2602-
"context": self.dcn_context_parallelism,
2603-
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2604-
"tensor": self.dcn_tensor_parallelism,
2605-
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2606-
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2607-
"model": self.dcn_tensor_parallelism,
2608-
"expert": self.dcn_expert_parallelism,
2609-
"autoregressive": self.dcn_autoregressive_parallelism,
2610-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2611-
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
2612-
}
2613-
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
2557+
ici_map = {
2558+
"diloco": self.ici_diloco_parallelism,
2559+
"data": self.ici_data_parallelism,
2560+
"stage": self.ici_pipeline_parallelism,
2561+
"fsdp": self.ici_fsdp_parallelism,
2562+
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2563+
"sequence": self.ici_sequence_parallelism,
2564+
"context": self.ici_context_parallelism,
2565+
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2566+
"tensor": self.ici_tensor_parallelism,
2567+
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2568+
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2569+
"model": self.ici_tensor_parallelism,
2570+
"expert": self.ici_expert_parallelism,
2571+
"autoregressive": self.ici_autoregressive_parallelism,
2572+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2573+
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
2574+
}
2575+
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2576+
2577+
dcn_map = {
2578+
"diloco": self.dcn_diloco_parallelism,
2579+
"data": self.dcn_data_parallelism,
2580+
"stage": self.dcn_pipeline_parallelism,
2581+
"fsdp": self.dcn_fsdp_parallelism,
2582+
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2583+
"sequence": self.dcn_sequence_parallelism,
2584+
"context": self.dcn_context_parallelism,
2585+
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2586+
"tensor": self.dcn_tensor_parallelism,
2587+
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2588+
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2589+
"model": self.dcn_tensor_parallelism,
2590+
"expert": self.dcn_expert_parallelism,
2591+
"autoregressive": self.dcn_autoregressive_parallelism,
2592+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2593+
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
2594+
}
2595+
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
26142596

26152597
# Diloco params
26162598
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)

src/maxtext/layers/decoders.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def setup(self):
307307
if self.config.using_pipeline_parallelism:
308308
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
309309
remat_policy = self.get_remat_policy()
310-
self.pipeline_module = pipeline.Pipeline(
310+
self.pipeline_module = pipeline.create_pipeline(
311311
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
312312
)
313313

@@ -794,12 +794,11 @@ def __call__(
794794
model_mode,
795795
)
796796
if cfg.using_pipeline_parallelism:
797-
if cfg.pipeline_fsdp_ag_once:
798-
logical_partition_spec = self.pipeline_module.get_weight_sharding(
799-
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
800-
)
801-
else:
802-
logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
797+
logical_partition_spec = (
798+
self.pipeline_module.get_weight_sharding(y, decoder_segment_ids, decoder_positions, deterministic, model_mode)
799+
if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat
800+
else None
801+
)
803802
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
804803
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
805804
dense_layer = RemattedBlockLayers[0]

0 commit comments

Comments
 (0)