Skip to content

Commit e4d05fc

Browse files
gagikaNuojCheng
authored and committed
Enable grain input pipeline save and restore for distillation.
Simple fix on debug sharding log; add all-gather insertion per repeat; working all-gather insertion (clean version); fsdp+pp bug free; add bsw checkpoint; split bsw all-gather into two; add custom vjp.
1 parent 0a5cce3 commit e4d05fc

5 files changed

Lines changed: 426 additions & 158 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights
290290
# It may be useful to do the reverse when the layers_per_stage is very large.
291291
# The below settings only have effect when using pipeline parallelism.
292292
scan_pipeline_iterations: True
293+
scan_pipeline_repeats: True
293294
scan_layers_per_stage: False
294295
set_remat_policy_on_pipeline_iterations: True
295296
set_remat_policy_on_layers_per_stage: False
@@ -900,7 +901,7 @@ xprof_e2e_enable_fw_throttle_event: False
900901
xprof_e2e_enable_fw_power_level_event: False
901902
xprof_e2e_enable_fw_thermal_event: False
902903

903-
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
904+
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
904905
debug_sharding: False # Prints model weights sharding info
905906

906907
# Checkpoint Structured logging

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,7 @@ class PipelineParallelism(BaseModel):
842842
)
843843
pipeline_fsdp_ag_once: bool = Field(False, description="If True, all-gather FSDP weights once per pipeline repeat.")
844844
scan_pipeline_iterations: bool = Field(True, description="Use jax.lax.scan over pipeline iterations.")
845+
scan_pipeline_repeats: bool = Field(True, description="Use jax.lax.scan over pipeline repeats.")
845846
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
846847
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
847848
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")

src/maxtext/layers/decoders.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,12 +745,9 @@ def __call__(
745745
model_mode,
746746
)
747747
if cfg.using_pipeline_parallelism:
748-
if cfg.pipeline_fsdp_ag_once:
749-
logical_partition_spec = self.pipeline_module.get_weight_sharding(
750-
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
751-
)
752-
else:
753-
logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
748+
logical_partition_spec = self.pipeline_module.get_weight_sharding(
749+
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
750+
)
754751
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
755752
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
756753
dense_layer = RemattedBlockLayers[0]
@@ -954,6 +951,13 @@ def __call__(
954951

955952
else:
956953
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
954+
logits = sharding.maybe_shard_with_logical(
955+
logits,
956+
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
957+
mesh=self.mesh,
958+
shard_mode=self.config.shard_mode,
959+
debug_sharding=self.config.debug_sharding,
960+
)
957961

958962
# The API of the Decoder is now a tuple, providing both the main output
959963
# and the raw hidden state needed for auxiliary tasks.

0 commit comments

Comments (0)