Skip to content

Commit 2176bcb

Browse files
committed
add custom vjp over repeat scan
1 parent 8573240 commit 2176bcb

2 files changed

Lines changed: 49 additions & 30 deletions

File tree

src/maxtext/layers/pipeline.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,8 +1389,6 @@ def run_iteration_scannable(model, loop_state, bsw):
13891389
# base scannable function used twice for real and bubble runs
13901390
base_scannable = functools.partial(
13911391
pipeline_utils.create_rematerialized_pipeline_stage,
1392-
model=self,
1393-
run_iteration_scannable=run_iteration_scannable,
13941392
deterministic=deterministic,
13951393
model_mode=model_mode,
13961394
logical_partition_spec=logical_partition_spec,

src/maxtext/utils/pipeline_utils.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from jax.sharding import PartitionSpec as P
2020
from flax import linen as nn
2121
from flax.linen.spmd import LogicallyPartitioned
22+
import jax.numpy as jnp
2223

2324

2425
def get_mesh_axis_dim_indices(physical_partition_spec, axis_name="fsdp"):
@@ -248,8 +249,6 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
248249

249250

250251
def create_rematerialized_pipeline_stage(
251-
model,
252-
run_iteration_scannable,
253252
length,
254253
deterministic,
255254
model_mode,
@@ -285,39 +284,61 @@ def create_rematerialized_pipeline_stage(
285284
the updated `loop_state`.
286285
"""
287286

288-
def execute_pipeline_stage(model, loop_state_and_bsw):
289-
loop_state, w_curr = loop_state_and_bsw
290-
# # Retrieve the specific weights needed for this pipeline chunk
291-
# bsw = model.both_weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292-
w_next = jax.remat(
293-
model.one_weight_prefetching,
294-
static_argnums=(1,),
295-
policy=jax.checkpoint_policies.nothing_saveable,
296-
)(
297-
pipeline_weights,
298-
physical_partition_spec,
299-
loop_state["loop_iteration"],
300-
)
301-
bsw = (w_curr, w_next)
287+
def execute_pipeline_stage_outer(model, loop_state_and_bsw):
288+
302289
scan_microbatches_fn = create_gradient_accumulation_scan(
303290
model=model,
304291
length=length,
305292
deterministic=deterministic,
306293
model_mode=model_mode,
307294
logical_partition_spec=logical_partition_spec,
308295
)
309-
loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids)
310-
w_curr, w_next = bsw
311-
del w_curr
312-
return (loop_state, w_next), None
313-
314-
return execute_pipeline_stage
315-
316-
# return nn.remat(
317-
# execute_pipeline_stage,
318-
# prevent_cse=not model.config.scan_pipeline_iterations,
319-
# policy=model.get_pipeline_remat_policy(),
320-
# )
296+
297+
remat_weight_prefetching = model.one_weight_prefetching
298+
299+
@jax.custom_vjp
300+
def execute_pipeline_stage(loop_state_and_bsw, pipeline_weights):
301+
return execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights)[0]
302+
303+
def execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights):
304+
loop_state, w_curr = loop_state_and_bsw
305+
# # Retrieve the specific weights needed for this pipeline chunk
306+
w_next = remat_weight_prefetching(
307+
pipeline_weights,
308+
physical_partition_spec,
309+
loop_state["loop_iteration"],
310+
)
311+
bsw = (w_curr, w_next)
312+
313+
(loop_state, bsw), scan_fn_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
314+
p_remat_weight_prefetching = functools.partial(
315+
remat_weight_prefetching,
316+
physical_partition_spec=physical_partition_spec,
317+
loop_iteration=loop_state["loop_iteration"],
318+
)
319+
remat_weight_prefetching_t = jax.linear_transpose(
320+
p_remat_weight_prefetching,
321+
pipeline_weights,
322+
)
323+
w_curr, w_next = bsw
324+
del w_curr
325+
return (loop_state, w_next), (scan_fn_vjp, remat_weight_prefetching_t)
326+
327+
def execute_pipeline_stage_custom_bwd(residuals, g_outputs):
328+
g_loop_state, g_w_next = g_outputs
329+
scan_fn_vjp, remat_weight_prefetching_t = residuals
330+
g_w_curr = jax.tree.map(jnp.zeros_like, g_w_next)
331+
g_bsw = (g_w_curr, g_w_next)
332+
g_loop_state, g_bsw, _, _ = scan_fn_vjp((g_loop_state, g_bsw))
333+
g_w_curr, g_w_next = g_bsw
334+
(g_pipeline_weights,) = remat_weight_prefetching_t(g_w_next)
335+
return (g_loop_state, g_w_curr), g_pipeline_weights
336+
337+
execute_pipeline_stage.defvjp(execute_pipeline_stage_custom_fwd, execute_pipeline_stage_custom_bwd)
338+
339+
return execute_pipeline_stage(loop_state_and_bsw, pipeline_weights), None
340+
341+
return execute_pipeline_stage_outer
321342

322343

323344
def create_flax_pipeline_scan(pipeline_stage_fn, length):

0 commit comments

Comments
 (0)