2 changes: 1 addition & 1 deletion src/maxtext/configs/base.yml
@@ -306,7 +306,7 @@ pipeline_fsdp_ag_per_repeat: False
# It may be useful to do the reverse when the layers_per_stage is very large.
# The below settings only have effect when using pipeline parallelism.
scan_pipeline_iterations: True
-scan_pipeline_repeats: True
+scan_pipeline_repeats: False
scan_layers_per_stage: False
set_remat_policy_on_pipeline_iterations: True
set_remat_policy_on_layers_per_stage: False
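As a reading aid for the config block above: each `scan_*` flag decides whether one level of the pipeline's loop nest is compiled as a `jax.lax.scan` or unrolled in Python. A schematic sketch of that nesting (assumed from the option names, not MaxText code; `run_layer` is a hypothetical stand-in for one layer's work):

```python
def run_layer(r, it, l, state):
  return state + 1  # stand-in for one layer's work

def pipeline(num_repeats, num_iterations, layers_per_stage, state=0):
  for r in range(num_repeats):             # governed by scan_pipeline_repeats
    for it in range(num_iterations):       # governed by scan_pipeline_iterations
      for l in range(layers_per_stage):    # governed by scan_layers_per_stage
        state = run_layer(r, it, l, state)
  return state

print(pipeline(2, 3, 4))  # 24: every (repeat, iteration, layer) step runs once
```

Flipping `scan_pipeline_repeats` to `False` therefore unrolls the outermost loop by default, trading a larger compiled program for more per-repeat optimization freedom.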
21 changes: 11 additions & 10 deletions src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
@@ -20,48 +20,49 @@
# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
# data parallel (DP) domain externally, making the `data` axis a necessary reference;
# second, it may be required for DCN communication.
#
+# The `context` axis is used to support fractional per-device batch sizes.
+#
# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or
# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to
# store prefetched weights.
-mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert']
-data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']]
+mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']
+data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'expert']],
['activation_batch_moe', ['data', 'fsdp', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp']],
['activation_batch_no_exp_moe', ['data', 'fsdp']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
-['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
+['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
['activation_heads', ['tensor']],
['activation_kv_heads', ['tensor']],
-['activation_length', ['expert']],
-['activation_attn_length', ['expert']],
-['activation_q_length', ['expert']],
+['activation_length', ['context', 'expert']],
+['activation_attn_length', ['context', 'expert']],
+['activation_q_length', ['context', 'expert']],
['activation_attn_embed', ['tensor']],
+['activation_norm_length', ['context']],
+['activation_norm_length_moe', ['context']],
['activation_embed', ['tensor']],
['activation_embed_moe', ['tensor']],
['activation_mlp', ['tensor']],
['activation_kv', ['tensor']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
['activation_kv_batch', ['data', 'fsdp', 'expert']],
['activation_kv_batch_no_exp', ['data', 'fsdp']],
['activation_kv_head_dim', ['tensor']],
['activation_vocab', ['tensor']],
['activation_stage', 'stage'],
['activation_exp', ['expert']],
['decode_batch', ['data', 'fsdp', 'expert']],
['mlp', ['tensor']],
['mlp_no_fsdp', ['tensor']],
['vocab', ['tensor']],
['heads', ['tensor']],
['q_heads', ['tensor']],
['kv_heads', ['tensor']],
-['embed', ['fsdp', 'expert']],
+['embed', ['fsdp', 'expert']], # remove context from embed sharding
+['embed_moe', ['fsdp', 'expert']],
['embed_no_exp', ['fsdp']],
['embed_no_exp_moe', ['fsdp']],
-['embed_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['norm', ['tensor']],
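For intuition about the rules above: each pair maps a logical axis name to one or more mesh axes, and an array annotated with logical names picks up the corresponding `PartitionSpec`. A minimal sketch of that resolution (simplified; `rules` below is a hand-written stand-in for two of the entries, not MaxText's parser):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hand-written stand-in for two of the rules above.
rules = {
    "activation_length": ("context", "expert"),
    "activation_embed": ("tensor",),
}

def logical_to_spec(logical_axes):
  # Logical names without a matching rule stay unsharded (None).
  return P(*(rules.get(name) for name in logical_axes))

devices = np.asarray(jax.devices()[:1]).reshape(1, 1, 1)
mesh = Mesh(devices, ("context", "expert", "tensor"))
spec = logical_to_spec(("activation_length", "activation_embed"))
# Sequence dim split over context*expert, embed dim over tensor.
sharding = NamedSharding(mesh, spec)
print(spec)  # PartitionSpec(('context', 'expert'), ('tensor',))
```

This is why adding `context` to the length-like rules, but deliberately not to `embed`, changes how activations are split without touching the weight layout.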
68 changes: 34 additions & 34 deletions src/maxtext/layers/pipeline.py
Expand Up @@ -184,6 +184,7 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
non-circular"""
# Stage 0 has processed one microbatch per loop iteration, but stage 1 runs one behind due to the pipeline bubble, and so on for later stages
microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
+microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage")))
microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
return microbatch_ids, repeat_ids
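A small numeric illustration of the id arithmetic above (values chosen for illustration): with `forwarding_delay=1`, 4 stages, and 4 microbatches, at `loop_iteration=5` each stage trails its predecessor by one microbatch, and the repeat id increments each time a stage wraps past the microbatch count:

```python
import jax.numpy as jnp

loop_iteration, forwarding_delay = 5, 1
num_stages, num_microbatches = 4, 4

processed = jnp.maximum(loop_iteration - forwarding_delay * jnp.arange(num_stages), 0)
print(processed)                       # [5 4 3 2]: stage 0 leads, later stages trail
print(processed % num_microbatches)    # [1 0 3 2] -> microbatch ids
print(processed // num_microbatches)   # [1 1 0 0] -> repeat ids
```

The new sharding hint simply pins this per-stage vector to the `stage` mesh axis.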
@@ -1006,8 +1007,12 @@ def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim):

def _gather_one(x, i):
idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
-replicated_sharding = NamedSharding(self.mesh, P())
-return x.at[idx].get(out_sharding=replicated_sharding)
+positions_sharding = (
+create_sharding(self.mesh, (None, "layers", "activation_length"))
+if self.config.shard_mode == ShardMode.EXPLICIT
+else None
+)
+return x.at[idx].get(out_sharding=positions_sharding)

return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
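Stripped of the sharding hints, the vmapped gather above reduces to the following standalone pattern (toy shapes assumed): each entry of `ids` selects one slice of `xs` along `ids_dim`, and the selections are restacked on that same axis:

```python
import jax
import jax.numpy as jnp

def gather_along_dim(xs, ids, ids_dim):
  def _gather_one(x, i):
    # Index dimension ids_dim with i, keep all other dimensions whole.
    idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
    return x[idx]
  return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)

xs = jnp.arange(24).reshape(4, 6)   # e.g. (num_microbatches, features)
ids = jnp.array([2, 0, 1])          # one microbatch id per stage
print(gather_along_dim(xs, ids, 0).shape)  # (3, 6)
```

The PR's change only swaps the `out_sharding` hint on the gather: a fully replicated result before, a `("layers", "activation_length")`-sharded result (explicit shard mode only) after.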

@@ -1168,11 +1173,12 @@ def from_repeat_weights_to_bsw(
self,
repeat_weights,
physical_partition_spec,
axes_to_gather=("fsdp", "fsdp_transpose", "expert"), # three major FSDP-like axes
axes_to_gather=("fsdp", "fsdp_transpose", "context", "expert"),
# TODO (chengnuojin) set use_shardmap=true after JAX >= 10.0.0 and use all_gather(..., to='invarying')
use_shardmap=False, # using shardmap produces additional reduce-scatter in backward pass
):
"""Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW."""
axes_to_remove = ["fsdp", "fsdp_transpose"]
axes_to_remove = ["fsdp", "fsdp_transpose", "context"]
bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove)

def _from_repeat_weights_to_bsw_shardmap(
@@ -1229,7 +1235,7 @@ def _apply_sharding_hint(weight, pspec):
weight,
sharding_name,
shard_mode=self.config.shard_mode,
-debug_sharding=self.config.shard_mode,
+debug_sharding=self.config.debug_sharding,
extra_stack_level=0,
)
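Stepping back to the `axes_to_remove` handling above: deriving the block-stage-weight specs amounts to dropping the gathered axes from each weight's `PartitionSpec`, so only the remaining axes (e.g. `tensor`) stay sharded after the all-gather. A hypothetical helper sketching what `derive_stage_weight_partition_specs` plausibly does for a single spec:

```python
from jax.sharding import PartitionSpec as P

def strip_axes(spec, axes_to_remove):
  """Drop the named mesh axes from every entry of a PartitionSpec."""
  def keep(entry):
    if entry is None:
      return None
    axes = entry if isinstance(entry, tuple) else (entry,)
    kept = tuple(a for a in axes if a not in axes_to_remove)
    return kept if kept else None
  return P(*(keep(e) for e in spec))

spec = P(("fsdp", "tensor"), "context", None)
print(strip_axes(spec, {"fsdp", "fsdp_transpose", "context"}))
# PartitionSpec(('tensor',), None, None)
```

Adding `context` to both `axes_to_gather` and `axes_to_remove` keeps the two views consistent now that the mesh carries a `context` axis.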

@@ -1240,21 +1246,15 @@ def _apply_sharding_hint(weight, pspec):
return _from_repeat_weights_to_bsw_hint(repeat_weights)

def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
"""Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
"""Triggers asynchronous FSDP-like all-gathers for the next pipeline steps.

By gathering weights for `loop_iteration + 1` right now, the network communication
-can overlap with the compute happening in `loop_iteration`. The dual-buffers
-are returned grouped in an explicit `jax.ad_checkpoint` to strictly control memory.
+can overlap with the compute happening in `loop_iteration`.
"""
-cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration)
-nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
-bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec)
-bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
-return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
+repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
+return self.from_repeat_weights_to_bsw(repeat_weights, physical_partition_spec)
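The overlap this buys can be pictured with a plain-Python double-buffering loop (illustrative only; `prefetch` and `compute` are hypothetical stand-ins for the weight gather and the stage work):

```python
def pipeline_loop(num_iters, prefetch, compute, loop_state):
  w_next = prefetch(0)  # gather weights for the first iteration
  for i in range(num_iters):
    w_curr = w_next
    if i + 1 < num_iters:
      # Issued before the compute, so the gather for i+1 can proceed
      # on the interconnect while iteration i runs on the cores.
      w_next = prefetch(i + 1)
    loop_state = compute(loop_state, w_curr)
  return loop_state
```

After this change the function returns only the next iteration's buffer; the current buffer is threaded through the loop carry instead of being rebuilt and checkpointed here each step.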

-def run_one_iteration(
-self, loop_state, bsw, weights, positions, segment_ids, deterministic, model_mode, logical_partition_spec
-):
+def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec):
"""Executes the forward/backward logic for a single microbatch inside the pipeline.

This acts as the core step function that our `jax.lax.scan` wrappers call. It routes
@@ -1339,7 +1339,6 @@ def __call__(
segment_idx = None

loop_state, bsw = self.init_states(inputs)
-weights = self.layers.variables
physical_partition_spec = logical_to_mesh(
logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
)
@@ -1353,12 +1352,11 @@

logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec)

-def run_iteration_scannable(model, loop_state, bsw, weights):
+def run_iteration_scannable(model, loop_state, bsw):
return (
model.run_one_iteration(
loop_state,
bsw,
-weights,
positions,
segment_ids,
deterministic,
@@ -1377,9 +1375,7 @@ def run_iteration_scannable(model, loop_state, bsw, weights):

# base scannable function used twice for real and bubble runs
base_scannable = functools.partial(
-pipeline_utils.create_rematerialized_pipeline_stage,
-model=self,
-run_iteration_scannable=run_iteration_scannable,
+pipeline_utils.create_pipeline_stage,
deterministic=deterministic,
model_mode=model_mode,
logical_partition_spec=logical_partition_spec,
@@ -1391,18 +1387,22 @@ def run_iteration_scannable(model, loop_state, bsw):
run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
run_bubbles_scannable = base_scannable(length=bubble_iterations)

-if self.config.scan_pipeline_repeats:
-run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
-pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
-)
-run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
-(loop_state, bsw, weights), _ = run_repeats_scanned(self, (loop_state, bsw, weights))
-(loop_state, bsw, weights), _ = run_bubbles_scanned(self, (loop_state, bsw, weights))
-else:
-for _ in range(self.config.num_pipeline_repeats):
-(loop_state, bsw, weights), _ = run_one_repeat_scannable(self, (loop_state, bsw, weights))
-for _ in range(bubble_iterations):
-(loop_state, bsw, weights), _ = run_iteration_scannable(self, loop_state, bsw, weights)
+run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
+pipeline_stage_fn=run_one_repeat_scannable,
+length=self.config.num_pipeline_repeats,
+remat_policy=self.get_pipeline_remat_policy(),
+use_scan=self.config.scan_pipeline_repeats,
+)
+run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
+pipeline_stage_fn=run_bubbles_scannable,
+length=1,
+remat_policy=self.get_pipeline_remat_policy(),
+use_scan=self.config.scan_pipeline_repeats,
+)
+initial_carry_repeats = (loop_state, bsw[0], self.layers.variables)
+(loop_state, w_curr, pipeline_weights), _ = run_repeats_scanned(self, initial_carry_repeats)
+initial_carry_bubbles = (loop_state, w_curr, pipeline_weights)
+(loop_state, _, pipeline_weights), _ = run_bubbles_scanned(self, initial_carry_bubbles)

final_output = self.realign_output_microbatches(loop_state["state_io"])
final_output = jnp.reshape(
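The refactor above folds the scan-vs-unroll branch (`use_scan=self.config.scan_pipeline_repeats`) and the remat policy into `create_flax_pipeline_scan` instead of branching at the call site. A minimal sketch of such a toggle (assumed helper, not the `pipeline_utils` API; remat omitted):

```python
import jax
import jax.numpy as jnp

def make_pipeline_scan(step_fn, length, use_scan=True):
  """Return a callable running `step_fn` `length` times over a carry."""
  if use_scan:
    def scanned(carry):
      # One traced body, compact compiled program.
      return jax.lax.scan(lambda c, _: (step_fn(c), None), carry, None, length=length)
    return scanned
  def unrolled(carry):
    # Every repeat inlined into the program.
    for _ in range(length):
      carry = step_fn(carry)
    return carry, None
  return unrolled

step = lambda c: c + 1.0
for flag in (True, False):
  carry, _ = make_pipeline_scan(step, length=4, use_scan=flag)(jnp.zeros(()))
  print(carry)  # 4.0 either way; only the compiled program shape differs
```

Keeping both code paths behind one factory lets the caller thread a single carry structure, here `(loop_state, current_weights, pipeline_weights)`, through repeats and bubble iterations alike.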