Skip to content

Commit 5a44af0

Browse files
gagika authored and NuojCheng committed
Enable grain input pipeline save and restore for distillation.
Simple fix on debug sharding log; add all-gather insertion per repeat (working); clean version of all-gather insertion; FSDP+PP bug-free; add BSW checkpoint; split BSW all-gather into two; add custom VJP.
1 parent 305e4e0 commit 5a44af0

5 files changed

Lines changed: 425 additions & 158 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights
295295
# It may be useful to do the reverse when the layers_per_stage is very large.
296296
# The below settings only have effect when using pipeline parallelism.
297297
scan_pipeline_iterations: True
298+
scan_pipeline_repeats: True
298299
scan_layers_per_stage: False
299300
set_remat_policy_on_pipeline_iterations: True
300301
set_remat_policy_on_layers_per_stage: False
@@ -911,7 +912,7 @@ xprof_e2e_enable_fw_throttle_event: False
911912
xprof_e2e_enable_fw_power_level_event: False
912913
xprof_e2e_enable_fw_thermal_event: False
913914

914-
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
915+
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
915916
debug_sharding: False # Prints model weights sharding info
916917

917918
# Checkpoint Structured logging

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ class PipelineParallelism(BaseModel):
852852
)
853853
pipeline_fsdp_ag_once: bool = Field(False, description="If True, all-gather FSDP weights once per pipeline repeat.")
854854
scan_pipeline_iterations: bool = Field(True, description="Use jax.lax.scan over pipeline iterations.")
855+
scan_pipeline_repeats: bool = Field(True, description="Use jax.lax.scan over pipeline repeats.")
855856
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
856857
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
857858
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")

src/maxtext/layers/decoders.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -793,12 +793,9 @@ def __call__(
793793
model_mode,
794794
)
795795
if cfg.using_pipeline_parallelism:
796-
if cfg.pipeline_fsdp_ag_once:
797-
logical_partition_spec = self.pipeline_module.get_weight_sharding(
798-
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
799-
)
800-
else:
801-
logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
796+
logical_partition_spec = self.pipeline_module.get_weight_sharding(
797+
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
798+
)
802799
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
803800
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
804801
dense_layer = RemattedBlockLayers[0]
@@ -1035,6 +1032,13 @@ def __call__(
10351032

10361033
else:
10371034
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
1035+
logits = sharding.maybe_shard_with_logical(
1036+
logits,
1037+
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1038+
mesh=self.mesh,
1039+
shard_mode=self.config.shard_mode,
1040+
debug_sharding=self.config.debug_sharding,
1041+
)
10381042

10391043
# The API of the Decoder is now a tuple, providing both the main output
10401044
# and the raw hidden state needed for auxiliary tasks.

0 commit comments

Comments (0)