Skip to content

Commit a1f6db3

Browse files
committed
add cse remat
1 parent a949bfc commit a1f6db3

4 files changed

Lines changed: 90 additions & 36 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ xprof_e2e_enable_fw_power_level_event: False
957957
xprof_e2e_enable_fw_thermal_event: False
958958
profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.
959959

960-
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
960+
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
961961
debug_sharding: False # Prints model weights sharding info
962962

963963
# Checkpoint Structured logging
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This logical rule is designed to optimize pipeline parallelism for large-scale jobs.
16+
# Key changes include removing expert weight sharding on the `q_lora` dimension, which
17+
# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when
18+
# EP x FSDP > 512.
19+
#
20+
# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
21+
# data parallel (DP) domain externally, making the `data` axis a necessary reference;
22+
# second, it may be required for DCN communication.
23+
#
24+
# Finally, the `context` axis is used to add fractional batch size support.
25+
mesh_axes: ['data', 'stage', 'fsdp', 'context', 'expert']
26+
data_sharding: [['data', 'stage', 'fsdp', 'context', 'expert']]
27+
logical_axis_rules: [
28+
['activation_batch', ['data', 'fsdp', 'expert']],
29+
['activation_batch_moe', ['data', 'fsdp', 'expert']],
30+
['activation_batch_no_exp', ['data', 'fsdp']],
31+
['activation_batch_no_exp_moe', ['data', 'fsdp']],
32+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
33+
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
34+
['activation_length', ['context', 'expert']],
35+
['activation_attn_length', ['context', 'expert']],
36+
['activation_attn_length_no_exp', ['context']],
37+
['activation_length_no_exp', ['context']],
38+
['activation_length_no_exp_moe', ['context']],
39+
['activation_norm_length', ['context']],
40+
['activation_norm_length_moe', ['context']],
41+
['activation_q_length', ['context', 'expert']],
42+
['activation_q_length_no_exp', ['context']],
43+
['prefill_activation_length', ['context']],
44+
['prefill_activation_norm_length', ['context']],
45+
['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
46+
['activation_kv_batch', ['data', 'fsdp', 'expert']],
47+
['activation_kv_batch_no_exp', ['data', 'fsdp']],
48+
['activation_vocab', ['context']],
49+
['activation_stage', 'stage'],
50+
['activation_exp', ['expert']],
51+
['decode_batch', ['data', 'fsdp', 'expert']],
52+
['embed', ['fsdp', 'context', 'expert']],
53+
['embed_no_exp', ['fsdp', 'context']],
54+
['embed_moe', ['fsdp', 'context', 'expert']],
55+
['embed_no_exp_moe', ['fsdp', 'context']],
56+
['q_lora', ['fsdp']],
57+
['kv_lora', ['fsdp']],
58+
['layers', 'stage'],
59+
['exp', 'expert'],
60+
['exp_with_fsdp', 'fsdp'],
61+
]

src/maxtext/layers/pipeline.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ def from_repeat_weights_to_bsw(
11731173
self,
11741174
repeat_weights,
11751175
physical_partition_spec,
1176-
axes_to_gather=("fsdp", "fsdp_transpose", "expert"), # three major FSDP-like axes
1176+
axes_to_gather=("fsdp", "fsdp_transpose", "context", "expert"), # four major FSDP-like axes
11771177
use_shardmap=False, # using shardmap produces additional reduce-scatter in backward pass
11781178
):
11791179
"""Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW."""
@@ -1351,7 +1351,6 @@ def __call__(
13511351
segment_idx = None
13521352

13531353
loop_state, bsw = self.init_states(inputs)
1354-
weights = self.layers.variables
13551354
physical_partition_spec = logical_to_mesh(
13561355
logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
13571356
)
@@ -1388,41 +1387,34 @@ def run_iteration_scannable(model, loop_state, bsw):
13881387

13891388
# base scannable function used twice for real and bubble runs
13901389
base_scannable = functools.partial(
1391-
pipeline_utils.create_rematerialized_pipeline_stage,
1390+
pipeline_utils.create_pipeline_stage,
13921391
deterministic=deterministic,
13931392
model_mode=model_mode,
13941393
logical_partition_spec=logical_partition_spec,
13951394
physical_partition_spec=physical_partition_spec,
13961395
positions=positions,
13971396
segment_ids=segment_ids,
1398-
pipeline_weights=weights,
13991397
)
14001398

14011399
run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
1402-
# run_one_repeat_scannable = nn.remat(
1403-
# run_one_repeat_scannable,
1404-
# prevent_cse=True,
1405-
# policy=self.get_pipeline_remat_policy()
1406-
# )
14071400
run_bubbles_scannable = base_scannable(length=bubble_iterations)
1408-
# run_bubbles_scannable = nn.remat(
1409-
# run_bubbles_scannable,
1410-
# prevent_cse=True,
1411-
# policy=self.get_pipeline_remat_policy()
1412-
# )
14131401

14141402
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
14151403
pipeline_stage_fn=run_one_repeat_scannable,
14161404
length=self.config.num_pipeline_repeats,
1405+
remat_policy=self.get_pipeline_remat_policy(),
14171406
use_scan=self.config.scan_pipeline_repeats,
14181407
)
14191408
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
14201409
pipeline_stage_fn=run_bubbles_scannable,
14211410
length=1,
1411+
remat_policy=self.get_pipeline_remat_policy(),
14221412
use_scan=self.config.scan_pipeline_repeats,
14231413
)
1424-
(loop_state, w_curr), _ = run_repeats_scanned(self, (loop_state, bsw[0]))
1425-
(loop_state, _), _ = run_bubbles_scanned(self, (loop_state, w_curr))
1414+
initial_carry_repeats = (loop_state, bsw[0], self.layers.variables)
1415+
(loop_state, w_curr, pipeline_weights), _ = run_repeats_scanned(self, initial_carry_repeats)
1416+
initial_carry_bubbles = (loop_state, w_curr, pipeline_weights)
1417+
(loop_state, _, pipeline_weights), _ = run_bubbles_scanned(self, initial_carry_bubbles)
14261418

14271419
final_output = self.realign_output_microbatches(loop_state["state_io"])
14281420
final_output = jnp.reshape(

src/maxtext/utils/pipeline_utils.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -248,24 +248,21 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
248248
return run_pipeline_microbatches_custom
249249

250250

251-
def create_rematerialized_pipeline_stage(
251+
def create_pipeline_stage(
252252
length,
253253
deterministic,
254254
model_mode,
255255
logical_partition_spec,
256256
physical_partition_spec,
257257
positions,
258258
segment_ids,
259-
pipeline_weights,
260259
):
261-
"""Builds a memory-checkpointed execution block for a single pipeline stage.
260+
"""Builds an execution block for a single pipeline stage.
262261
263262
This function prepares the state for a specific chunk of pipeline execution by:
264263
1. Prefetching the required weights for the current stage/loop iteration.
265264
2. Executing `length` microbatches using either a memory-efficient `jax.lax.scan`
266265
(if `scan_pipeline_iterations` is True) or an unrolled Python `for` loop.
267-
3. Wrapping the entire stage block in `flax.linen.remat` to discard and recompute
268-
activations during the backward pass based on the model's policy.
269266
270267
Args:
271268
length: The number of microbatches to process in this stage.
@@ -275,14 +272,15 @@ def create_rematerialized_pipeline_stage(
275272
physical_partition_spec: Rules for physical device mesh mappings (used in prefetching).
276273
positions: Position IDs for the sequence.
277274
segment_ids: Segment/Attention routing IDs for the sequence.
278-
pipeline_weights: The fully gathered pipeline weights explicitly passed via closure.
279275
280276
Returns:
281-
A function decorated with `nn.remat` that takes `(model, loop_state)` and returns
282-
the updated `loop_state`.
277+
A function that takes `(model, carry)`, where `carry` is
278+
`(loop_state, w_curr, pipeline_weights)`, and returns the updated carry tuple.
283279
"""
284280

285-
def execute_pipeline_stage_outer(model, loop_state_and_bsw):
281+
def execute_pipeline_stage_outer(model, carry):
282+
283+
loop_state, w_curr, pipeline_weights = carry
286284

287285
scan_microbatches_fn = create_gradient_accumulation_scan(
288286
model=model,
@@ -295,12 +293,11 @@ def execute_pipeline_stage_outer(model, loop_state_and_bsw):
295293
remat_weight_prefetching = model.one_weight_prefetching
296294

297295
@jax.custom_vjp
298-
def execute_pipeline_stage(loop_state_and_bsw, pipeline_weights):
299-
return execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights)[0]
296+
def execute_pipeline_stage(loop_state, w_curr, pipeline_weights):
297+
return execute_pipeline_stage_custom_fwd(loop_state, w_curr, pipeline_weights)[0]
300298

301-
def execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights):
302-
loop_state, w_curr = loop_state_and_bsw
303-
# # Retrieve the specific weights needed for this pipeline chunk
299+
def execute_pipeline_stage_custom_fwd(loop_state, w_curr, pipeline_weights):
300+
# Retrieve the specific weights needed for this pipeline chunk
304301
w_next = remat_weight_prefetching(
305302
pipeline_weights,
306303
physical_partition_spec,
@@ -328,17 +325,17 @@ def execute_pipeline_stage_custom_bwd(residuals, g_outputs):
328325
g_loop_state, g_bsw, _, _ = scan_fn_vjp((g_loop_state, g_bsw))
329326
g_w_curr, g_w_next = g_bsw
330327
(g_pipeline_weights,) = remat_weight_prefetching_t(g_w_next)
331-
return (g_loop_state, g_w_curr), g_pipeline_weights
328+
return g_loop_state, g_w_curr, g_pipeline_weights
332329

333330
execute_pipeline_stage.defvjp(execute_pipeline_stage_custom_fwd, execute_pipeline_stage_custom_bwd)
334331

335-
return execute_pipeline_stage(loop_state_and_bsw, pipeline_weights), None
332+
return (*execute_pipeline_stage(loop_state, w_curr, pipeline_weights), pipeline_weights), None
336333

337334
return execute_pipeline_stage_outer
338335

339336

340-
def create_flax_pipeline_scan(pipeline_stage_fn, length, use_scan=True):
341-
"""Wraps the pipeline stage execution in a `flax.linen.scan`.
337+
def create_flax_pipeline_scan(pipeline_stage_fn, length, remat_policy, use_scan=True):
338+
"""Wraps the pipeline stage execution in a `flax.linen.scan` and `flax.linen.remat`.
342339
343340
This lifts the pipeline stage function so it can be repeated sequentially over
344341
the specified length. It safely handles Flax-specific state collections, ensuring
@@ -348,6 +345,7 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length, use_scan=True):
348345
Args:
349346
pipeline_stage_fn: The function representing a single pipeline stage
350347
(usually created by `create_pipeline_stage`).
348+
remat_policy: Rematerialization policy applied to each pipeline stage via `nn.remat`.
351349
length: The total number of pipeline stages/repeats to scan over.
352350
use_scan: Either scan over repeats or unroll the scan.
353351
@@ -356,7 +354,10 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length, use_scan=True):
356354
"""
357355
unroll_length = 1 if use_scan else length
358356
return nn.scan(
359-
pipeline_stage_fn,
357+
nn.remat(
358+
pipeline_stage_fn,
359+
policy=remat_policy,
360+
),
360361
variable_axes={
361362
"summaries": 0,
362363
"aux_loss": 0,

0 commit comments

Comments
 (0)