Commit 721bb5a

refactor pr
1 parent c231424 commit 721bb5a

8 files changed

Lines changed: 384 additions & 333 deletions

src/maxtext/configs/base.yml

Lines changed: 3 additions & 6 deletions
@@ -275,10 +275,6 @@ pipeline_parallel_layers: -1 # Pipeline only this number of layers - for the rem
 # PP degree divides the number of layers.
 # By default (when set to -1) we pipeline all of the decoder layers.
 
-# Pipeline weight prefetching is an advanced SPMD pipeline parallelism improvement technique
-# When enabled, it prefetches necessary weight gathering ahead of microbatched computation, therefore reducing collectives
-use_pipeline_weight_prefetching: False
-
 # num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages.
 # Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices
 num_pipeline_microbatches: -1
@@ -291,8 +287,9 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights
 # to only one stage's worth, however we only execute one all-gather and reduce across per repeat, as opposed
 # to every microbatch. This is similar to zero-1 sharding, since we also don't need to all gather the FSDP weights in the backward pass.
 # An alternative to setting this to true may be to replace any FSDP with DP and use optimizer offloading if necessary.
-# A more optimal behavior is to all-gather at the start of each repeat, which would ideally get the best of both worlds -
-# a small amount of memory and time, however this has proven hard to implement in SPMD, see b/364386697 for more.
+pipeline_fsdp_ag_per_repeat: False
+# Pipeline weight prefetching per repeat is an advanced SPMD pipeline parallelism improvement technique
+# When enabled, it prefetches necessary weight gathering ahead of microbatched computation, therefore reducing collectives
 
 # There are two loops for PP:
 # 1) Outer loop over microbatches (pipeline iterations)
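
The batching arithmetic described in the comments above can be sanity-checked with a small standalone sketch. This is plain illustrative Python, not MaxText code; the helper name and example numbers are made up.

# Illustrative helper mirroring the comments above; not part of MaxText.
def pipeline_batch_shapes(per_device_batch_size, num_devices, num_pipeline_stages,
                          num_pipeline_microbatches=-1):
  """Returns (global_batch_size, num_pipeline_microbatches, microbatch_size)."""
  global_batch_size = per_device_batch_size * num_devices
  if num_pipeline_microbatches == -1:
    # By default the number of microbatches equals the number of stages.
    num_pipeline_microbatches = num_pipeline_stages
  # num_pipeline_microbatches must be a multiple of the number of pipeline stages.
  assert num_pipeline_microbatches % num_pipeline_stages == 0
  assert global_batch_size % num_pipeline_microbatches == 0
  return global_batch_size, num_pipeline_microbatches, global_batch_size // num_pipeline_microbatches

# Example: 2 per-device batch * 16 devices = 32 global, 4 stages -> 4 microbatches of 8.
print(pipeline_batch_shapes(2, 16, 4))  # (32, 4, 8)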

src/maxtext/configs/types.py

Lines changed: 4 additions & 4 deletions
@@ -840,7 +840,7 @@ class IciParallelism(BaseModel):
 class PipelineParallelism(BaseModel):
   """Configuration for pipeline parallelism."""
 
-  use_pipeline_weight_prefetching: bool = Field(
+  pipeline_fsdp_ag_per_repeat: bool = Field(
       False, description="Enable weight prefetching for circular pipeline parallelism."
   )
   num_layers_per_pipeline_stage: int = Field(1, description="Number of layers to place on each pipeline stage.")
@@ -2240,7 +2240,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     )
     self.num_pipeline_repeats = num_pipeline_repeats
 
-    if self.use_pipeline_weight_prefetching:
+    if self.pipeline_fsdp_ag_per_repeat:
       assert self.num_pipeline_repeats > 1, "Pipeline weight prefetching only supports circular pipeline."
       assert (
           self.num_layers_per_pipeline_stage == 1
@@ -2556,7 +2556,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "expert": self.ici_expert_parallelism,
         "autoregressive": self.ici_autoregressive_parallelism,
         "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
-        "attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
+        "attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
     }
     self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
@@ -2576,7 +2576,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "expert": self.dcn_expert_parallelism,
         "autoregressive": self.dcn_autoregressive_parallelism,
         "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
-        "attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
+        "attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
     }
     self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
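
The constraint asserted in the second hunk can be restated as a standalone check. This is a simplified sketch whose names only mirror the config fields; it is not the actual types.py implementation.

# Simplified sketch of the check in the hunk above; illustration only.
def check_pipeline_fsdp_ag_per_repeat(pipeline_fsdp_ag_per_repeat: bool,
                                      num_pipeline_repeats: int,
                                      num_layers_per_pipeline_stage: int) -> None:
  if not pipeline_fsdp_ag_per_repeat:
    return
  # Per-repeat weight gathering requires a circular pipeline (more than one repeat).
  assert num_pipeline_repeats > 1, "Pipeline weight prefetching only supports circular pipeline."
  # And, per the hunk above, exactly one layer per pipeline stage.
  assert num_layers_per_pipeline_stage == 1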

src/maxtext/layers/decoders.py

Lines changed: 1 addition & 1 deletion
@@ -796,7 +796,7 @@ def __call__(
     if cfg.using_pipeline_parallelism:
       logical_partition_spec = (
           self.pipeline_module.get_weight_sharding(y, decoder_segment_ids, decoder_positions, deterministic, model_mode)
-          if cfg.pipeline_fsdp_ag_once or cfg.use_pipeline_weight_prefetching
+          if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat
           else None
       )
       if cfg.decoder_block == DecoderBlockType.DEEPSEEK:

src/maxtext/layers/pipeline.py

Lines changed: 195 additions & 139 deletions
Large diffs are not rendered by default.
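
The bulk of the refactor lives in this file. Based only on the base.yml comments above, the intent of pipeline_fsdp_ag_per_repeat can be sketched as a control-flow outline; the helpers (all_gather_weights, run_microbatch) are hypothetical stand-ins, and the real code is an SPMD scan that differs in detail.

# Conceptual sketch: gather FSDP-sharded weights once per repeat and prefetch
# the next repeat's gather ahead of the microbatched computation, instead of
# re-gathering for every microbatch. Hypothetical helpers, not MaxText APIs.
def circular_pipeline_per_repeat_ag(weights_per_repeat, microbatches,
                                    all_gather_weights, run_microbatch):
  outputs = []
  gathered = all_gather_weights(weights_per_repeat[0])  # prefetch first repeat's weights
  for r in range(len(weights_per_repeat)):
    # Issue the next repeat's gather early so it can overlap with compute:
    # one collective per repeat rather than one per microbatch.
    next_gathered = (all_gather_weights(weights_per_repeat[r + 1])
                     if r + 1 < len(weights_per_repeat) else None)
    for mb in microbatches:
      outputs.append(run_microbatch(gathered, mb))
    gathered = next_gathered
  return outputs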

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 1 addition & 1 deletion
@@ -809,7 +809,7 @@ def gmm(
         group_sizes,
         representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
     )
-    if config.use_qwix_quantization or (config.using_pipeline_parallelism and config.use_pipeline_weight_prefetching):
+    if config.use_qwix_quantization:
      output = megablox.gmm(
          lhs=inputs,
          rhs=kernel,