Skip to content

Commit 0a64711

Browse files
gagikaNuojCheng
authored and committed
Enable grain input pipeline save and restore for distillation.
Simple fix on debug sharding log; add all-gather insertion per repeat; working all-gather insertion (clean version); fsdp+pp bug free; add bsw checkpoint; split bsw all-gather into two; add custom vjp.
1 parent 42ec065 commit 0a64711

5 files changed

Lines changed: 425 additions & 158 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -290,6 +290,7 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights
290290
# It may be useful to do the reverse when the layers_per_stage is very large.
291291
# The below settings only have effect when using pipeline parallelism.
292292
scan_pipeline_iterations: True
293+
scan_pipeline_repeats: True
293294
scan_layers_per_stage: False
294295
set_remat_policy_on_pipeline_iterations: True
295296
set_remat_policy_on_layers_per_stage: False
@@ -900,7 +901,7 @@ xprof_e2e_enable_fw_throttle_event: False
900901
xprof_e2e_enable_fw_power_level_event: False
901902
xprof_e2e_enable_fw_thermal_event: False
902903

903-
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
904+
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
904905
debug_sharding: False # Prints model weights sharding info
905906

906907
# Checkpoint Structured logging

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -841,6 +841,7 @@ class PipelineParallelism(BaseModel):
841841
)
842842
pipeline_fsdp_ag_once: bool = Field(False, description="If True, all-gather FSDP weights once per pipeline repeat.")
843843
scan_pipeline_iterations: bool = Field(True, description="Use jax.lax.scan over pipeline iterations.")
844+
scan_pipeline_repeats: bool = Field(True, description="Use jax.lax.scan over pipeline repeats.")
844845
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
845846
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
846847
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")

src/maxtext/layers/decoders.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -782,12 +782,9 @@ def __call__(
782782
model_mode,
783783
)
784784
if cfg.using_pipeline_parallelism:
785-
if cfg.pipeline_fsdp_ag_once:
786-
logical_partition_spec = self.pipeline_module.get_weight_sharding(
787-
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
788-
)
789-
else:
790-
logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
785+
logical_partition_spec = self.pipeline_module.get_weight_sharding(
786+
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
787+
)
791788
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
792789
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
793790
dense_layer = RemattedBlockLayers[0]
@@ -997,6 +994,13 @@ def __call__(
997994

998995
else:
999996
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
997+
logits = sharding.maybe_shard_with_logical(
998+
logits,
999+
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1000+
mesh=self.mesh,
1001+
shard_mode=self.config.shard_mode,
1002+
debug_sharding=self.config.debug_sharding,
1003+
)
10001004

10011005
# The API of the Decoder is now a tuple, providing both the main output
10021006
# and the raw hidden state needed for auxiliary tasks.

0 commit comments

Comments (0)