Skip to content

Commit a949bfc

Browse files
committed
add custom vjp over repeat scan
1 parent 8573240 commit a949bfc

3 files changed

Lines changed: 72 additions & 35 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ pipeline_fsdp_ag_per_repeat: False
299299
# It may be useful to do the reverse when the layers_per_stage is very large.
300300
# The below settings only have effect when using pipeline parallelism.
301301
scan_pipeline_iterations: True
302-
scan_pipeline_repeats: True
302+
scan_pipeline_repeats: False
303303
scan_layers_per_stage: False
304304
set_remat_policy_on_pipeline_iterations: True
305305
set_remat_policy_on_layers_per_stage: False

src/maxtext/layers/pipeline.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,8 +1389,6 @@ def run_iteration_scannable(model, loop_state, bsw):
13891389
# base scannable function used twice for real and bubble runs
13901390
base_scannable = functools.partial(
13911391
pipeline_utils.create_rematerialized_pipeline_stage,
1392-
model=self,
1393-
run_iteration_scannable=run_iteration_scannable,
13941392
deterministic=deterministic,
13951393
model_mode=model_mode,
13961394
logical_partition_spec=logical_partition_spec,
@@ -1401,12 +1399,28 @@ def run_iteration_scannable(model, loop_state, bsw):
14011399
)
14021400

14031401
run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
1402+
# run_one_repeat_scannable = nn.remat(
1403+
# run_one_repeat_scannable,
1404+
# prevent_cse=True,
1405+
# policy=self.get_pipeline_remat_policy()
1406+
# )
14041407
run_bubbles_scannable = base_scannable(length=bubble_iterations)
1408+
# run_bubbles_scannable = nn.remat(
1409+
# run_bubbles_scannable,
1410+
# prevent_cse=True,
1411+
# policy=self.get_pipeline_remat_policy()
1412+
# )
14051413

14061414
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
1407-
pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
1415+
pipeline_stage_fn=run_one_repeat_scannable,
1416+
length=self.config.num_pipeline_repeats,
1417+
use_scan=self.config.scan_pipeline_repeats,
1418+
)
1419+
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
1420+
pipeline_stage_fn=run_bubbles_scannable,
1421+
length=1,
1422+
use_scan=self.config.scan_pipeline_repeats,
14081423
)
1409-
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
14101424
(loop_state, w_curr), _ = run_repeats_scanned(self, (loop_state, bsw[0]))
14111425
(loop_state, _), _ = run_bubbles_scanned(self, (loop_state, w_curr))
14121426

src/maxtext/utils/pipeline_utils.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from jax.sharding import PartitionSpec as P
2020
from flax import linen as nn
2121
from flax.linen.spmd import LogicallyPartitioned
22+
import jax.numpy as jnp
2223

2324

2425
def get_mesh_axis_dim_indices(physical_partition_spec, axis_name="fsdp"):
@@ -248,8 +249,6 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
248249

249250

250251
def create_rematerialized_pipeline_stage(
251-
model,
252-
run_iteration_scannable,
253252
length,
254253
deterministic,
255254
model_mode,
@@ -269,8 +268,6 @@ def create_rematerialized_pipeline_stage(
269268
activations during the backward pass based on the model's policy.
270269
271270
Args:
272-
model: The model instance containing configuration and prefetching logic.
273-
run_iteration_scannable: A fallback function for executing a single iteration unrolled.
274271
length: The number of microbatches to process in this stage.
275272
deterministic: Whether to run deterministically (e.g., disable dropout).
276273
model_mode: The operational mode (e.g., 'train').
@@ -285,42 +282,62 @@ def create_rematerialized_pipeline_stage(
285282
the updated `loop_state`.
286283
"""
287284

288-
def execute_pipeline_stage(model, loop_state_and_bsw):
289-
loop_state, w_curr = loop_state_and_bsw
290-
# # Retrieve the specific weights needed for this pipeline chunk
291-
# bsw = model.both_weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292-
w_next = jax.remat(
293-
model.one_weight_prefetching,
294-
static_argnums=(1,),
295-
policy=jax.checkpoint_policies.nothing_saveable,
296-
)(
297-
pipeline_weights,
298-
physical_partition_spec,
299-
loop_state["loop_iteration"],
300-
)
301-
bsw = (w_curr, w_next)
285+
def execute_pipeline_stage_outer(model, loop_state_and_bsw):
286+
302287
scan_microbatches_fn = create_gradient_accumulation_scan(
303288
model=model,
304289
length=length,
305290
deterministic=deterministic,
306291
model_mode=model_mode,
307292
logical_partition_spec=logical_partition_spec,
308293
)
309-
loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids)
310-
w_curr, w_next = bsw
311-
del w_curr
312-
return (loop_state, w_next), None
313294

314-
return execute_pipeline_stage
295+
remat_weight_prefetching = model.one_weight_prefetching
296+
297+
@jax.custom_vjp
298+
def execute_pipeline_stage(loop_state_and_bsw, pipeline_weights):
299+
return execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights)[0]
300+
301+
def execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights):
302+
loop_state, w_curr = loop_state_and_bsw
303+
# # Retrieve the specific weights needed for this pipeline chunk
304+
w_next = remat_weight_prefetching(
305+
pipeline_weights,
306+
physical_partition_spec,
307+
loop_state["loop_iteration"],
308+
)
309+
bsw = (w_curr, w_next)
310+
p_remat_weight_prefetching = functools.partial(
311+
remat_weight_prefetching,
312+
physical_partition_spec=physical_partition_spec,
313+
loop_iteration=loop_state["loop_iteration"],
314+
)
315+
remat_weight_prefetching_t = jax.linear_transpose(
316+
p_remat_weight_prefetching,
317+
pipeline_weights,
318+
)
319+
(loop_state, bsw), scan_fn_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
320+
w_curr, w_next = bsw
321+
return (loop_state, w_next), (scan_fn_vjp, remat_weight_prefetching_t)
322+
323+
def execute_pipeline_stage_custom_bwd(residuals, g_outputs):
324+
g_loop_state, g_w_next = g_outputs
325+
scan_fn_vjp, remat_weight_prefetching_t = residuals
326+
g_w_curr = jax.tree.map(jnp.zeros_like, g_w_next)
327+
g_bsw = (g_w_curr, g_w_next)
328+
g_loop_state, g_bsw, _, _ = scan_fn_vjp((g_loop_state, g_bsw))
329+
g_w_curr, g_w_next = g_bsw
330+
(g_pipeline_weights,) = remat_weight_prefetching_t(g_w_next)
331+
return (g_loop_state, g_w_curr), g_pipeline_weights
332+
333+
execute_pipeline_stage.defvjp(execute_pipeline_stage_custom_fwd, execute_pipeline_stage_custom_bwd)
334+
335+
return execute_pipeline_stage(loop_state_and_bsw, pipeline_weights), None
315336

316-
# return nn.remat(
317-
# execute_pipeline_stage,
318-
# prevent_cse=not model.config.scan_pipeline_iterations,
319-
# policy=model.get_pipeline_remat_policy(),
320-
# )
337+
return execute_pipeline_stage_outer
321338

322339

323-
def create_flax_pipeline_scan(pipeline_stage_fn, length):
340+
def create_flax_pipeline_scan(pipeline_stage_fn, length, use_scan=True):
324341
"""Wraps the pipeline stage execution in a `flax.linen.scan`.
325342
326343
This lifts the pipeline stage function so it can be repeated sequentially over
@@ -332,10 +349,12 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
332349
pipeline_stage_fn: The function representing a single pipeline stage
333350
(usually created by `create_rematerialized_pipeline_stage`).
334351
length: The total number of pipeline stages/repeats to scan over.
352+
use_scan: If True, scan over the repeats with a single compiled iteration; if False, fully unroll the loop over repeats.
335353
336354
Returns:
337355
A Flax scanned function that executes the full pipeline schedule.
338356
"""
357+
unroll_length = 1 if use_scan else length
339358
return nn.scan(
340359
pipeline_stage_fn,
341360
variable_axes={
@@ -344,7 +363,11 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
344363
"intermediates": 0,
345364
"hyper_params": 0,
346365
},
366+
variable_broadcast=[
367+
"_overwrite_with_gradient",
368+
"non_trainable",
369+
],
347370
split_rngs={"random": True},
348371
length=length,
349-
unroll=length,
372+
unroll=unroll_length,
350373
)

0 commit comments

Comments (0)