Skip to content

Commit 2d66c61

Browse files
committed
add custom vjp over repeat scan
1 parent 8573240 commit 2d66c61

3 files changed

Lines changed: 58 additions & 33 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ pipeline_fsdp_ag_per_repeat: False
299299
# It may be useful to do the reverse when the layers_per_stage is very large.
300300
# The below settings only have effect when using pipeline parallelism.
301301
scan_pipeline_iterations: True
302-
scan_pipeline_repeats: True
302+
scan_pipeline_repeats: False
303303
scan_layers_per_stage: False
304304
set_remat_policy_on_pipeline_iterations: True
305305
set_remat_policy_on_layers_per_stage: False

src/maxtext/layers/pipeline.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,8 +1389,6 @@ def run_iteration_scannable(model, loop_state, bsw):
13891389
# base scannable function used twice for real and bubble runs
13901390
base_scannable = functools.partial(
13911391
pipeline_utils.create_rematerialized_pipeline_stage,
1392-
model=self,
1393-
run_iteration_scannable=run_iteration_scannable,
13941392
deterministic=deterministic,
13951393
model_mode=model_mode,
13961394
logical_partition_spec=logical_partition_spec,
@@ -1404,9 +1402,15 @@ def run_iteration_scannable(model, loop_state, bsw):
14041402
run_bubbles_scannable = base_scannable(length=bubble_iterations)
14051403

14061404
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
1407-
pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
1405+
pipeline_stage_fn=run_one_repeat_scannable,
1406+
length=self.config.num_pipeline_repeats,
1407+
use_scan=self.config.scan_pipeline_repeats,
1408+
)
1409+
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
1410+
pipeline_stage_fn=run_bubbles_scannable,
1411+
length=1,
1412+
use_scan=self.config.scan_pipeline_repeats,
14081413
)
1409-
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
14101414
(loop_state, w_curr), _ = run_repeats_scanned(self, (loop_state, bsw[0]))
14111415
(loop_state, _), _ = run_bubbles_scanned(self, (loop_state, w_curr))
14121416

src/maxtext/utils/pipeline_utils.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from jax.sharding import PartitionSpec as P
2020
from flax import linen as nn
2121
from flax.linen.spmd import LogicallyPartitioned
22+
import jax.numpy as jnp
2223

2324

2425
def get_mesh_axis_dim_indices(physical_partition_spec, axis_name="fsdp"):
@@ -248,8 +249,6 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
248249

249250

250251
def create_rematerialized_pipeline_stage(
251-
model,
252-
run_iteration_scannable,
253252
length,
254253
deterministic,
255254
model_mode,
@@ -285,42 +284,62 @@ def create_rematerialized_pipeline_stage(
285284
the updated `loop_state`.
286285
"""
287286

288-
def execute_pipeline_stage(model, loop_state_and_bsw):
289-
loop_state, w_curr = loop_state_and_bsw
290-
# # Retrieve the specific weights needed for this pipeline chunk
291-
# bsw = model.both_weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292-
w_next = jax.remat(
293-
model.one_weight_prefetching,
294-
static_argnums=(1,),
295-
policy=jax.checkpoint_policies.nothing_saveable,
296-
)(
297-
pipeline_weights,
298-
physical_partition_spec,
299-
loop_state["loop_iteration"],
300-
)
301-
bsw = (w_curr, w_next)
287+
def execute_pipeline_stage_outer(model, loop_state_and_bsw):
288+
302289
scan_microbatches_fn = create_gradient_accumulation_scan(
303290
model=model,
304291
length=length,
305292
deterministic=deterministic,
306293
model_mode=model_mode,
307294
logical_partition_spec=logical_partition_spec,
308295
)
309-
loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids)
310-
w_curr, w_next = bsw
311-
del w_curr
312-
return (loop_state, w_next), None
313296

314-
return execute_pipeline_stage
297+
remat_weight_prefetching = model.one_weight_prefetching
298+
299+
@jax.custom_vjp
300+
def execute_pipeline_stage(loop_state_and_bsw, pipeline_weights):
301+
return execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights)[0]
302+
303+
def execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights):
304+
loop_state, w_curr = loop_state_and_bsw
305+
# # Retrieve the specific weights needed for this pipeline chunk
306+
w_next = remat_weight_prefetching(
307+
pipeline_weights,
308+
physical_partition_spec,
309+
loop_state["loop_iteration"],
310+
)
311+
bsw = (w_curr, w_next)
312+
p_remat_weight_prefetching = functools.partial(
313+
remat_weight_prefetching,
314+
physical_partition_spec=physical_partition_spec,
315+
loop_iteration=loop_state["loop_iteration"],
316+
)
317+
remat_weight_prefetching_t = jax.linear_transpose(
318+
p_remat_weight_prefetching,
319+
pipeline_weights,
320+
)
321+
(loop_state, bsw), scan_fn_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
322+
w_curr, w_next = bsw
323+
return (loop_state, w_next), (scan_fn_vjp, remat_weight_prefetching_t)
324+
325+
def execute_pipeline_stage_custom_bwd(residuals, g_outputs):
326+
g_loop_state, g_w_next = g_outputs
327+
scan_fn_vjp, remat_weight_prefetching_t = residuals
328+
g_w_curr = jax.tree.map(jnp.zeros_like, g_w_next)
329+
g_bsw = (g_w_curr, g_w_next)
330+
g_loop_state, g_bsw, _, _ = scan_fn_vjp((g_loop_state, g_bsw))
331+
g_w_curr, g_w_next = g_bsw
332+
(g_pipeline_weights,) = remat_weight_prefetching_t(g_w_next)
333+
return (g_loop_state, g_w_curr), g_pipeline_weights
334+
335+
execute_pipeline_stage.defvjp(execute_pipeline_stage_custom_fwd, execute_pipeline_stage_custom_bwd)
336+
337+
return execute_pipeline_stage(loop_state_and_bsw, pipeline_weights), None
315338

316-
# return nn.remat(
317-
# execute_pipeline_stage,
318-
# prevent_cse=not model.config.scan_pipeline_iterations,
319-
# policy=model.get_pipeline_remat_policy(),
320-
# )
339+
return execute_pipeline_stage_outer
321340

322341

323-
def create_flax_pipeline_scan(pipeline_stage_fn, length):
342+
def create_flax_pipeline_scan(pipeline_stage_fn, length, use_scan=True):
324343
"""Wraps the pipeline stage execution in a `flax.linen.scan`.
325344
326345
This lifts the pipeline stage function so it can be repeated sequentially over
@@ -332,10 +351,12 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
332351
pipeline_stage_fn: The function representing a single pipeline stage
333352
(usually created by `create_rematerialized_pipeline_stage`).
334353
length: The total number of pipeline stages/repeats to scan over.
354+
use_scan: Either scan over repeats or unroll the scan.
335355
336356
Returns:
337357
A Flax scanned function that executes the full pipeline schedule.
338358
"""
359+
unroll_length = 1 if use_scan else length
339360
return nn.scan(
340361
pipeline_stage_fn,
341362
variable_axes={
@@ -346,5 +367,5 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
346367
},
347368
split_rngs={"random": True},
348369
length=length,
349-
unroll=length,
370+
unroll=unroll_length,
350371
)

0 commit comments

Comments
 (0)