
Commit efda1e8

add cse remat
1 parent a949bfc commit efda1e8

3 files changed: 31 additions & 51 deletions


src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
@@ -957,7 +957,7 @@ xprof_e2e_enable_fw_power_level_event: False
 xprof_e2e_enable_fw_thermal_event: False
 profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.

-log_config: False # Prints the config (after defaults have been set by pyconfig logic)
+log_config: True # Prints the config (after defaults have been set by pyconfig logic)
 debug_sharding: False # Prints model weights sharding info

 # Checkpoint Structured logging

src/maxtext/layers/pipeline.py

Lines changed: 9 additions & 30 deletions
@@ -1173,7 +1173,7 @@ def from_repeat_weights_to_bsw(
       self,
       repeat_weights,
       physical_partition_spec,
-      axes_to_gather=("fsdp", "fsdp_transpose", "expert"),  # three major FSDP-like axes
+      axes_to_gather=("fsdp", "fsdp_transpose", "context", "expert"),  # three major FSDP-like axes
       use_shardmap=False,  # using shardmap produces additional reduce-scatter in backward pass
   ):
     """Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW."""
@@ -1244,20 +1244,7 @@ def _apply_sharding_hint(weight, pspec):
       return _from_repeat_weights_to_bsw_shardmap(repeat_weights, physical_partition_spec, axes_to_gather=axes_to_gather)
     return _from_repeat_weights_to_bsw_hint(repeat_weights)

-  def both_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
-    """Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
-
-    By gathering weights for `loop_iteration + 1` right now, the network communication
-    can overlap with the compute happening in `loop_iteration`. The dual-buffers
-    are returned grouped in an explicit `jax.ad_checkpoint` to strictly control memory.
-    """
-    cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration)
-    nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
-    bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec)
-    bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
-    return bsw_0, bsw_1
-
-  def one_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
+  def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
     """Triggers asynchronous FSDP-like all-gathers for the next pipeline steps.

     By gathering weights for `loop_iteration + 1` right now, the network communication
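Editor's note: the docstring describes double-buffered prefetching, where the all-gather for step `i + 1` is issued while step `i` computes. A minimal sketch of that pattern with plain `jax.lax.scan` (illustrative names, not the MaxText implementation):

```python
import jax
import jax.numpy as jnp

def gather_block(all_weights, i):
  # Stand-in for the FSDP-like all-gather of one repeat's weight block.
  return jax.lax.dynamic_index_in_dim(all_weights, i, axis=0, keepdims=False)

def pipelined_apply(all_weights, xs):
  num_steps = all_weights.shape[0]

  def step(carry, i):
    x, w_curr = carry
    w_next = gather_block(all_weights, (i + 1) % num_steps)  # prefetch the next block
    x = x @ w_curr                                           # compute can overlap the gather
    return (x, w_next), None

  w0 = gather_block(all_weights, 0)  # the first block is gathered before the loop starts
  (y, _), _ = jax.lax.scan(step, (xs, w0), jnp.arange(num_steps))
  return y

weights = jnp.full((4, 16, 16), 0.1)  # 4 "repeats" of a 16x16 block
acts = jnp.ones((8, 16))
print(pipelined_apply(weights, acts).shape)  # (8, 16)
```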
@@ -1351,7 +1338,6 @@ def __call__(
     segment_idx = None

     loop_state, bsw = self.init_states(inputs)
-    weights = self.layers.variables
     physical_partition_spec = logical_to_mesh(
         logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
     )
@@ -1388,41 +1374,34 @@ def run_iteration_scannable(model, loop_state, bsw):

     # base scannable function used twice for real and bubble runs
     base_scannable = functools.partial(
-        pipeline_utils.create_rematerialized_pipeline_stage,
+        pipeline_utils.create_pipeline_stage,
         deterministic=deterministic,
         model_mode=model_mode,
         logical_partition_spec=logical_partition_spec,
         physical_partition_spec=physical_partition_spec,
         positions=positions,
         segment_ids=segment_ids,
-        pipeline_weights=weights,
     )

     run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
-    # run_one_repeat_scannable = nn.remat(
-    #     run_one_repeat_scannable,
-    #     prevent_cse=True,
-    #     policy=self.get_pipeline_remat_policy()
-    # )
     run_bubbles_scannable = base_scannable(length=bubble_iterations)
-    # run_bubbles_scannable = nn.remat(
-    #     run_bubbles_scannable,
-    #     prevent_cse=True,
-    #     policy=self.get_pipeline_remat_policy()
-    # )

     run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
         pipeline_stage_fn=run_one_repeat_scannable,
         length=self.config.num_pipeline_repeats,
+        remat_policy=self.get_pipeline_remat_policy(),
         use_scan=self.config.scan_pipeline_repeats,
     )
     run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
         pipeline_stage_fn=run_bubbles_scannable,
         length=1,
+        remat_policy=self.get_pipeline_remat_policy(),
         use_scan=self.config.scan_pipeline_repeats,
     )
-    (loop_state, w_curr), _ = run_repeats_scanned(self, (loop_state, bsw[0]))
-    (loop_state, _), _ = run_bubbles_scanned(self, (loop_state, w_curr))
+    initial_carry_repeats = (loop_state, bsw[0], self.layers.variables)
+    (loop_state, w_curr, pipeline_weights), _ = run_repeats_scanned(self, initial_carry_repeats)
+    initial_carry_bubbles = (loop_state, w_curr, pipeline_weights)
+    (loop_state, _, pipeline_weights), _ = run_bubbles_scanned(self, initial_carry_bubbles)

     final_output = self.realign_output_microbatches(loop_state["state_io"])
     final_output = jnp.reshape(
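Editor's note on the hunk above: rather than passing `pipeline_weights=weights` into the stage builder up front and (per the commented-out lines) wrapping each scannable in `nn.remat(..., prevent_cse=True)`, the weight pytree now travels through the scan carry and the remat policy is handed to `create_flax_pipeline_scan`. A minimal sketch of why making the weights an explicit carried input matters for rematerialization, using plain `jax.checkpoint` (names are hypothetical, not the MaxText code):

```python
import jax
import jax.numpy as jnp

def stage(carry):
  x, w_curr, all_weights = carry
  x = jnp.tanh(x @ w_curr)
  w_next = all_weights  # stand-in for prefetching the next block from the full pytree
  return x, w_next, all_weights

# Because the weights arrive through the carry, the checkpointed step sees them as
# explicit inputs and the policy controls what is saved vs. recomputed.
rematted_stage = jax.checkpoint(stage, policy=jax.checkpoint_policies.nothing_saveable)

def run(x, w0, all_weights, steps=4):
  def body(carry, _):
    return rematted_stage(carry), None
  (y, _, _), _ = jax.lax.scan(body, (x, w0, all_weights), None, length=steps)
  return y.sum()

x = jnp.ones((8, 16))
w = jnp.full((16, 16), 0.05)
grads = jax.grad(run, argnums=2)(x, w, w)
print(grads.shape)  # (16, 16): gradients flow to the weights carried through the scan
```

Here `jax.checkpoint_policies.nothing_saveable` is only a placeholder for whatever `get_pipeline_remat_policy()` returns; the point is that the carried weights are ordinary inputs to the checkpointed step.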

src/maxtext/utils/pipeline_utils.py

Lines changed: 21 additions & 20 deletions
@@ -248,24 +248,21 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
   return run_pipeline_microbatches_custom


-def create_rematerialized_pipeline_stage(
+def create_pipeline_stage(
     length,
     deterministic,
     model_mode,
     logical_partition_spec,
     physical_partition_spec,
     positions,
     segment_ids,
-    pipeline_weights,
 ):
-  """Builds a memory-checkpointed execution block for a single pipeline stage.
+  """Builds an execution block for a single pipeline stage.

   This function prepares the state for a specific chunk of pipeline execution by:
     1. Prefetching the required weights for the current stage/loop iteration.
     2. Executing `length` microbatches using either a memory-efficient `jax.lax.scan`
        (if `scan_pipeline_iterations` is True) or an unrolled Python `for` loop.
-    3. Wrapping the entire stage block in `flax.linen.remat` to discard and recompute
-       activations during the backward pass based on the model's policy.

   Args:
     length: The number of microbatches to process in this stage.
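Editor's note: step 2 in the docstring (scan vs. unrolled loop over `length` microbatches) is a standard trade-off between compile time and flexibility. A small illustrative sketch with hypothetical names, not the MaxText code:

```python
import jax
import jax.numpy as jnp

def microbatch_step(state, x):
  # Stand-in for one microbatch of pipeline work; `state` accumulates the result.
  return state + jnp.tanh(x).sum(), None

def run_microbatches(xs, use_scan=True):
  init = jnp.zeros(())
  if use_scan:
    # Compact HLO, lower compile time: the loop body is traced once.
    state, _ = jax.lax.scan(microbatch_step, init, xs)
  else:
    # Fully unrolled Python loop: each microbatch is traced separately.
    state = init
    for i in range(xs.shape[0]):
      state, _ = microbatch_step(state, xs[i])
  return state

xs = jnp.ones((4, 8, 16))  # 4 microbatches
print(run_microbatches(xs), run_microbatches(xs, use_scan=False))  # identical results
```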
@@ -275,14 +272,15 @@ def create_rematerialized_pipeline_stage(
     physical_partition_spec: Rules for physical device mesh mappings (used in prefetching).
     positions: Position IDs for the sequence.
     segment_ids: Segment/Attention routing IDs for the sequence.
-    pipeline_weights: The fully gathered pipeline weights explicitly passed via closure.

   Returns:
-    A function decorated with `nn.remat` that takes `(model, loop_state)` and returns
-    the updated `loop_state`.
+    A function that takes `(model, loop_state, weight, pipeline_weights)` and returns
+    the updated loop_state and new weight.
   """

-  def execute_pipeline_stage_outer(model, loop_state_and_bsw):
+  def execute_pipeline_stage_outer(model, carry):
+
+    loop_state, w_curr, pipeline_weights = carry

     scan_microbatches_fn = create_gradient_accumulation_scan(
         model=model,
@@ -292,15 +290,14 @@ def execute_pipeline_stage_outer(model, loop_state_and_bsw):
         logical_partition_spec=logical_partition_spec,
     )

-    remat_weight_prefetching = model.one_weight_prefetching
+    remat_weight_prefetching = model.weight_prefetching

     @jax.custom_vjp
-    def execute_pipeline_stage(loop_state_and_bsw, pipeline_weights):
-      return execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights)[0]
+    def execute_pipeline_stage(loop_state, w_curr, pipeline_weights):
+      return execute_pipeline_stage_custom_fwd(loop_state, w_curr, pipeline_weights)[0]

-    def execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights):
-      loop_state, w_curr = loop_state_and_bsw
-      # # Retrieve the specific weights needed for this pipeline chunk
+    def execute_pipeline_stage_custom_fwd(loop_state, w_curr, pipeline_weights):
+      # Retrieve the specific weights needed for this pipeline chunk
       w_next = remat_weight_prefetching(
           pipeline_weights,
           physical_partition_spec,
@@ -328,17 +325,17 @@ def execute_pipeline_stage_custom_bwd(residuals, g_outputs):
       g_loop_state, g_bsw, _, _ = scan_fn_vjp((g_loop_state, g_bsw))
       g_w_curr, g_w_next = g_bsw
       (g_pipeline_weights,) = remat_weight_prefetching_t(g_w_next)
-      return (g_loop_state, g_w_curr), g_pipeline_weights
+      return g_loop_state, g_w_curr, g_pipeline_weights

     execute_pipeline_stage.defvjp(execute_pipeline_stage_custom_fwd, execute_pipeline_stage_custom_bwd)

-    return execute_pipeline_stage(loop_state_and_bsw, pipeline_weights), None
+    return (*execute_pipeline_stage(loop_state, w_curr, pipeline_weights), pipeline_weights), None

   return execute_pipeline_stage_outer


-def create_flax_pipeline_scan(pipeline_stage_fn, length, use_scan=True):
-  """Wraps the pipeline stage execution in a `flax.linen.scan`.
+def create_flax_pipeline_scan(pipeline_stage_fn, length, remat_policy, use_scan=True):
+  """Wraps the pipeline stage execution in a `flax.linen.scan` and `flax.linen.remat`.

   This lifts the pipeline stage function so it can be repeated sequentially over
   the specified length. It safely handles Flax-specific state collections, ensuring
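Editor's note on the `custom_vjp` change above: once `execute_pipeline_stage` takes three positional arguments instead of a `(loop_state, bsw)` pair plus weights, its backward rule must return one cotangent per positional input as a flat 3-tuple, which is why `g_loop_state, g_w_curr, g_pipeline_weights` is now returned un-nested. A minimal sketch of that contract (illustrative, not the MaxText code):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def stage(x, w_curr, all_weights):
  return jnp.tanh(x @ w_curr)  # `all_weights` is unused in this toy primal

def stage_fwd(x, w_curr, all_weights):
  y = jnp.tanh(x @ w_curr)
  return y, (x, w_curr, y, all_weights)  # residuals saved for the backward pass

def stage_bwd(residuals, g_y):
  x, w_curr, y, all_weights = residuals
  g_pre = g_y * (1.0 - y * y)  # d/dz tanh(z) = 1 - tanh(z)**2
  # One cotangent per positional input, returned as a flat tuple.
  return g_pre @ w_curr.T, x.T @ g_pre, jnp.zeros_like(all_weights)

stage.defvjp(stage_fwd, stage_bwd)

x = jnp.ones((4, 8))
w = jnp.full((8, 8), 0.1)
grads = jax.grad(lambda *args: stage(*args).sum(), argnums=(0, 1, 2))(x, w, w)
print(grads[1].shape)  # (8, 8)
```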
@@ -348,6 +345,7 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length, use_scan=True):
   Args:
     pipeline_stage_fn: The function representing a single pipeline stage
       (usually created by `create_rematerialized_pipeline_stage`).
+    remat_policy: remat policy used for pipeline stage
     length: The total number of pipeline stages/repeats to scan over.
     use_scan: Either scan over repeats or unroll the scan.
@@ -356,7 +354,10 @@
   """
   unroll_length = 1 if use_scan else length
   return nn.scan(
-      pipeline_stage_fn,
+      nn.remat(
+          pipeline_stage_fn,
+          policy=remat_policy,
+      ),
       variable_axes={
           "summaries": 0,
           "aux_loss": 0,
