Skip to content

Commit 062a066

Browse files
committed
add all gather insertion per repeat
1 parent 8434e35 commit 062a066

4 files changed

Lines changed: 248 additions & 102 deletions

File tree

src/MaxText/layers/decoders.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -741,12 +741,9 @@ def __call__(
741741
model_mode,
742742
)
743743
if cfg.using_pipeline_parallelism:
744-
if cfg.pipeline_fsdp_ag_once:
745-
logical_partition_spec = self.pipeline_module.get_weight_sharding(
746-
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
747-
)
748-
else:
749-
logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
744+
logical_partition_spec = self.pipeline_module.get_weight_sharding(
745+
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
746+
)
750747
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
751748
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
752749
dense_layer = RemattedBlockLayers[0]

0 commit comments

Comments
 (0)