Commit c38fa86

Merge pull request #3412 from AI-Hypercomputer:chengnuojin-pp-more

PiperOrigin-RevId: 893641594

2 parents: 3b4244e + 049ba3f

5 files changed: 179 additions & 117 deletions

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
@@ -306,7 +306,7 @@ pipeline_fsdp_ag_per_repeat: False
 # It may be useful to do the reverse when the layers_per_stage is very large.
 # The below settings only have effect when using pipeline parallelism.
 scan_pipeline_iterations: True
-scan_pipeline_repeats: True
+scan_pipeline_repeats: False
 scan_layers_per_stage: False
 set_remat_policy_on_pipeline_iterations: True
 set_remat_policy_on_layers_per_stage: False
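This flips the default so pipeline repeats are unrolled as a Python loop unless scanning is requested; the flag is consumed as `use_scan` in pipeline.py below. A minimal sketch of the scan-versus-unroll choice the flag controls, in plain jax.lax.scan terms (names here are illustrative, not the MaxText helpers):

import jax
import jax.numpy as jnp

def run_repeats(step_fn, carry, num_repeats, use_scan):
  """Run step_fn num_repeats times, either scanned or unrolled.

  Scanning compiles the body once (smaller HLO, faster compiles); unrolling
  lets XLA specialize each repeat (sometimes faster at runtime).
  """
  if use_scan:
    carry, _ = jax.lax.scan(lambda c, _: (step_fn(c), None), carry, length=num_repeats)
  else:
    for _ in range(num_repeats):
      carry = step_fn(carry)
  return carry

out = run_repeats(lambda c: c * 0.5, jnp.ones(4), num_repeats=3, use_scan=False)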

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 11 additions & 10 deletions
@@ -20,48 +20,49 @@
 # The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
 # data parallel (DP) domain externally, making the `data` axis a necessary reference;
 # second, it may be required for DCN communication.
+#
+# The `context` axis is used to support fractional per-device batch sizes.
 #
 # Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or
 # `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to
 # store prefetched weights.
-mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert']
-data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']]
+mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']
+data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'expert']],
   ['activation_batch_moe', ['data', 'fsdp', 'expert']],
   ['activation_batch_no_exp', ['data', 'fsdp']],
   ['activation_batch_no_exp_moe', ['data', 'fsdp']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
-  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
+  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
   ['activation_heads', ['tensor']],
   ['activation_kv_heads', ['tensor']],
-  ['activation_length', ['expert']],
-  ['activation_attn_length', ['expert']],
-  ['activation_q_length', ['expert']],
+  ['activation_length', ['context', 'expert']],
+  ['activation_attn_length', ['context', 'expert']],
+  ['activation_q_length', ['context', 'expert']],
   ['activation_attn_embed', ['tensor']],
+  ['activation_norm_length', ['context']],
+  ['activation_norm_length_moe', ['context']],
   ['activation_embed', ['tensor']],
   ['activation_embed_moe', ['tensor']],
   ['activation_mlp', ['tensor']],
   ['activation_kv', ['tensor']],
-  ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
   ['activation_kv_batch', ['data', 'fsdp', 'expert']],
   ['activation_kv_batch_no_exp', ['data', 'fsdp']],
   ['activation_kv_head_dim', ['tensor']],
   ['activation_vocab', ['tensor']],
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert']],
-  ['decode_batch', ['data', 'fsdp', 'expert']],
   ['mlp', ['tensor']],
   ['mlp_no_fsdp', ['tensor']],
   ['vocab', ['tensor']],
   ['heads', ['tensor']],
   ['q_heads', ['tensor']],
   ['kv_heads', ['tensor']],
-  ['embed', ['fsdp', 'expert']],
+  ['embed', ['fsdp', 'expert']],  # remove context from embed sharding
   ['embed_moe', ['fsdp', 'expert']],
   ['embed_no_exp', ['fsdp']],
   ['embed_no_exp_moe', ['fsdp']],
-  ['embed_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['norm', ['tensor']],
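The new `context` axis is what makes a fractional per-device batch size expressible: when there are more chips than batch examples, the sequence dimension is sharded as well. A minimal sketch with a plain JAX mesh (axis sizes, shapes, and the 8-device setup are illustrative, not taken from this config):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 visible devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8 on CPU).
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("data", "context"))

# 4 examples on 8 chips: batch sharded 4-way on `data`, sequence 2-way on
# `context`, so each device holds half an example.
batch = jax.numpy.zeros((4, 2048, 512))  # (batch, sequence, embed)
sharded = jax.device_put(batch, NamedSharding(mesh, P("data", "context", None)))
print(sharded.sharding.shard_shape(batch.shape))  # (1, 1024, 512)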

src/maxtext/layers/pipeline.py

Lines changed: 34 additions & 34 deletions
@@ -184,6 +184,7 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
     non-circular"""
     # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is one behind due to the bubble, and so on for later stages
     microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
+    microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage")))
     microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids
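For intuition, the ids come from simple integer arithmetic on how many microbatches each stage has processed; the added line only attaches a sharding hint over `stage`. A standalone sketch with toy values:

import jax.numpy as jnp

num_stages, forwarding_delay, num_microbatches = 4, 1, 8
loop_iteration = 5

# Stage s lags stage 0 by forwarding_delay * s iterations (the pipeline bubble).
processed = jnp.maximum(loop_iteration - forwarding_delay * jnp.arange(num_stages), 0)
microbatch_ids = processed % num_microbatches   # which microbatch each stage holds
repeat_ids = processed // num_microbatches      # which circular repeat it is on
# processed == [5 4 3 2], so microbatch_ids == [5 4 3 2] and repeat_ids == [0 0 0 0]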
@@ -1006,8 +1007,12 @@ def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim):

     def _gather_one(x, i):
       idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
-      replicated_sharding = NamedSharding(self.mesh, P())
-      return x.at[idx].get(out_sharding=replicated_sharding)
+      positions_sharding = (
+          create_sharding(self.mesh, (None, "layers", "activation_length"))
+          if self.config.shard_mode == ShardMode.EXPLICIT
+          else None
+      )
+      return x.at[idx].get(out_sharding=positions_sharding)

     return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
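The gather-per-stage pattern itself is easy to see in isolation: vmap maps over the ids while broadcasting the stacked inputs, and each id selects one slice along `ids_dim`. A toy version without the sharding machinery (no MaxText helpers assumed):

import jax
import jax.numpy as jnp

def gather_per_stage(xs, ids, ids_dim):
  """For each id, slice xs along ids_dim; results are stacked on that same axis."""
  def _gather_one(x, i):
    idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
    return x[idx]
  return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)

xs = jnp.arange(24).reshape(6, 4)          # 6 microbatches of width 4
ids = jnp.array([2, 0, 5])                 # one microbatch id per stage
print(gather_per_stage(xs, ids, 0).shape)  # (3, 4)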

@@ -1168,11 +1173,12 @@ def from_repeat_weights_to_bsw(
       self,
       repeat_weights,
       physical_partition_spec,
-      axes_to_gather=("fsdp", "fsdp_transpose", "expert"),  # three major FSDP-like axes
+      axes_to_gather=("fsdp", "fsdp_transpose", "context", "expert"),
+      # TODO (chengnuojin) set use_shardmap=true after JAX >= 10.0.0 and use all_gather(..., to='invarying')
       use_shardmap=False,  # using shardmap produces an additional reduce-scatter in the backward pass
   ):
     """Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW."""
-    axes_to_remove = ["fsdp", "fsdp_transpose"]
+    axes_to_remove = ["fsdp", "fsdp_transpose", "context"]
     bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove)

     def _from_repeat_weights_to_bsw_shardmap(
@@ -1229,7 +1235,7 @@ def _apply_sharding_hint(weight, pspec):
           weight,
           sharding_name,
           shard_mode=self.config.shard_mode,
-          debug_sharding=self.config.shard_mode,
+          debug_sharding=self.config.debug_sharding,
           extra_stack_level=0,
       )

@@ -1240,21 +1246,15 @@ def _apply_sharding_hint(weight, pspec):
     return _from_repeat_weights_to_bsw_hint(repeat_weights)

   def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
-    """Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
+    """Triggers asynchronous FSDP-like all-gathers for the next pipeline steps.

     By gathering weights for `loop_iteration + 1` right now, the network communication
-    can overlap with the compute happening in `loop_iteration`. The dual-buffers
-    are returned grouped in an explicit `jax.ad_checkpoint` to strictly control memory.
+    can overlap with the compute happening in `loop_iteration`.
     """
-    cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration)
-    nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
-    bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec)
-    bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
-    return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
+    repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
+    return self.from_repeat_weights_to_bsw(repeat_weights, physical_partition_spec)

-  def run_one_iteration(
-      self, loop_state, bsw, weights, positions, segment_ids, deterministic, model_mode, logical_partition_spec
-  ):
+  def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec):
     """Executes the forward/backward logic for a single microbatch inside the pipeline.

     This acts as the core step function that our `jax.lax.scan` wrappers call. It routes
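The simplified `weight_prefetching` gathers only the buffer for `loop_iteration + 1`; the current iteration's buffer now arrives through the loop carry rather than a checkpointed dual-buffer pair. The double-buffering shape of this, as a minimal standalone sketch (the function names and toy compute are illustrative):

import jax.numpy as jnp

def gather_weights(step):
  """Stand-in for the FSDP-like all-gather of one repeat's weights."""
  return jnp.full((2, 2), float(step))

x = jnp.ones(2)
w_curr = gather_weights(0)  # prime the first buffer
for step in range(3):
  w_next = gather_weights(step + 1)  # can overlap with compute under async dispatch
  x = x @ w_curr                     # compute with the already-gathered buffer
  w_curr = w_next                    # rotate buffers for the next iteration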
@@ -1339,7 +1339,6 @@ def __call__(
     segment_idx = None

     loop_state, bsw = self.init_states(inputs)
-    weights = self.layers.variables
     physical_partition_spec = logical_to_mesh(
         logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
     )
@@ -1353,12 +1352,11 @@ def __call__(

     logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec)

-    def run_iteration_scannable(model, loop_state, bsw, weights):
+    def run_iteration_scannable(model, loop_state, bsw):
       return (
           model.run_one_iteration(
               loop_state,
               bsw,
-              weights,
               positions,
               segment_ids,
               deterministic,
@@ -1377,9 +1375,7 @@ def run_iteration_scannable(model, loop_state, bsw, weights):

     # base scannable function used twice for real and bubble runs
     base_scannable = functools.partial(
-        pipeline_utils.create_rematerialized_pipeline_stage,
-        model=self,
-        run_iteration_scannable=run_iteration_scannable,
+        pipeline_utils.create_pipeline_stage,
         deterministic=deterministic,
         model_mode=model_mode,
         logical_partition_spec=logical_partition_spec,
@@ -1391,18 +1387,22 @@ def run_iteration_scannable(model, loop_state, bsw, weights):
     run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
     run_bubbles_scannable = base_scannable(length=bubble_iterations)

-    if self.config.scan_pipeline_repeats:
-      run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
-          pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
-      )
-      run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
-      (loop_state, bsw, weights), _ = run_repeats_scanned(self, (loop_state, bsw, weights))
-      (loop_state, bsw, weights), _ = run_bubbles_scanned(self, (loop_state, bsw, weights))
-    else:
-      for _ in range(self.config.num_pipeline_repeats):
-        (loop_state, bsw, weights), _ = run_one_repeat_scannable(self, (loop_state, bsw, weights))
-      for _ in range(bubble_iterations):
-        (loop_state, bsw, weights), _ = run_iteration_scannable(self, loop_state, bsw, weights)
+    run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
+        pipeline_stage_fn=run_one_repeat_scannable,
+        length=self.config.num_pipeline_repeats,
+        remat_policy=self.get_pipeline_remat_policy(),
+        use_scan=self.config.scan_pipeline_repeats,
+    )
+    run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
+        pipeline_stage_fn=run_bubbles_scannable,
+        length=1,
+        remat_policy=self.get_pipeline_remat_policy(),
+        use_scan=self.config.scan_pipeline_repeats,
+    )
+    initial_carry_repeats = (loop_state, bsw[0], self.layers.variables)
+    (loop_state, w_curr, pipeline_weights), _ = run_repeats_scanned(self, initial_carry_repeats)
+    initial_carry_bubbles = (loop_state, w_curr, pipeline_weights)
+    (loop_state, _, pipeline_weights), _ = run_bubbles_scanned(self, initial_carry_bubbles)

     final_output = self.realign_output_microbatches(loop_state["state_io"])
     final_output = jnp.reshape(
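Both the scanned and unrolled paths now thread one carry tuple of (loop state, current weight buffer, full pipeline weights), mirroring `initial_carry_repeats` above. A plain jax.lax.scan sketch of that carry threading (the step body is a toy, not the MaxText stage function):

import jax
import jax.numpy as jnp

def one_repeat(carry, _):
  loop_state, w_curr, weights = carry
  loop_state = loop_state + jnp.sum(w_curr)  # toy "compute" with the current buffer
  w_next = weights * 0.5                     # toy "prefetch" of the next buffer
  return (loop_state, w_next, weights), None

init_carry = (jnp.float32(0.0), jnp.ones(2), jnp.ones(2))
(loop_state, w_curr, weights), _ = jax.lax.scan(one_repeat, init_carry, length=3)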
