1919from jax .sharding import PartitionSpec as P
2020from flax import linen as nn
2121from flax .linen .spmd import LogicallyPartitioned
22+ import jax .numpy as jnp
2223
2324
2425def get_mesh_axis_dim_indices (physical_partition_spec , axis_name = "fsdp" ):
@@ -248,8 +249,6 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
248249
249250
250251def create_rematerialized_pipeline_stage (
251- model ,
252- run_iteration_scannable ,
253252 length ,
254253 deterministic ,
255254 model_mode ,
@@ -285,42 +284,62 @@ def create_rematerialized_pipeline_stage(
285284 the updated `loop_state`.
286285 """
287286
288- def execute_pipeline_stage (model , loop_state_and_bsw ):
289- loop_state , w_curr = loop_state_and_bsw
290- # # Retrieve the specific weights needed for this pipeline chunk
291- # bsw = model.both_weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292- w_next = jax .remat (
293- model .one_weight_prefetching ,
294- static_argnums = (1 ,),
295- policy = jax .checkpoint_policies .nothing_saveable ,
296- )(
297- pipeline_weights ,
298- physical_partition_spec ,
299- loop_state ["loop_iteration" ],
300- )
301- bsw = (w_curr , w_next )
287+ def execute_pipeline_stage_outer (model , loop_state_and_bsw ):
288+
302289 scan_microbatches_fn = create_gradient_accumulation_scan (
303290 model = model ,
304291 length = length ,
305292 deterministic = deterministic ,
306293 model_mode = model_mode ,
307294 logical_partition_spec = logical_partition_spec ,
308295 )
309- loop_state , bsw = scan_microbatches_fn (loop_state , bsw , positions , segment_ids )
310- w_curr , w_next = bsw
311- del w_curr
312- return (loop_state , w_next ), None
313296
314- return execute_pipeline_stage
297+ remat_weight_prefetching = model .one_weight_prefetching
298+
299+ @jax .custom_vjp
300+ def execute_pipeline_stage (loop_state_and_bsw , pipeline_weights ):
301+ return execute_pipeline_stage_custom_fwd (loop_state_and_bsw , pipeline_weights )[0 ]
302+
303+ def execute_pipeline_stage_custom_fwd (loop_state_and_bsw , pipeline_weights ):
304+ loop_state , w_curr = loop_state_and_bsw
305+ # Fetch w_next for the current loop iteration; it is paired with the carried-in w_curr to form bsw.
306+ w_next = remat_weight_prefetching (
307+ pipeline_weights ,
308+ physical_partition_spec ,
309+ loop_state ["loop_iteration" ],
310+ )
311+ bsw = (w_curr , w_next )
312+ p_remat_weight_prefetching = functools .partial (
313+ remat_weight_prefetching ,
314+ physical_partition_spec = physical_partition_spec ,
315+ loop_iteration = loop_state ["loop_iteration" ],
316+ )
317+ remat_weight_prefetching_t = jax .linear_transpose (
318+ p_remat_weight_prefetching ,
319+ pipeline_weights ,
320+ )
321+ (loop_state , bsw ), scan_fn_vjp = jax .vjp (scan_microbatches_fn , loop_state , bsw , positions , segment_ids )
322+ w_curr , w_next = bsw
323+ return (loop_state , w_next ), (scan_fn_vjp , remat_weight_prefetching_t )
324+
325+ def execute_pipeline_stage_custom_bwd (residuals , g_outputs ):
326+ g_loop_state , g_w_next = g_outputs
327+ scan_fn_vjp , remat_weight_prefetching_t = residuals
328+ g_w_curr = jax .tree .map (jnp .zeros_like , g_w_next )
329+ g_bsw = (g_w_curr , g_w_next )
330+ g_loop_state , g_bsw , _ , _ = scan_fn_vjp ((g_loop_state , g_bsw ))
331+ g_w_curr , g_w_next = g_bsw
332+ (g_pipeline_weights ,) = remat_weight_prefetching_t (g_w_next )
333+ return (g_loop_state , g_w_curr ), g_pipeline_weights
334+
335+ execute_pipeline_stage .defvjp (execute_pipeline_stage_custom_fwd , execute_pipeline_stage_custom_bwd )
336+
337+ return execute_pipeline_stage (loop_state_and_bsw , pipeline_weights ), None
315338
316- # return nn.remat(
317- # execute_pipeline_stage,
318- # prevent_cse=not model.config.scan_pipeline_iterations,
319- # policy=model.get_pipeline_remat_policy(),
320- # )
339+ return execute_pipeline_stage_outer
321340
322341
323- def create_flax_pipeline_scan (pipeline_stage_fn , length ):
342+ def create_flax_pipeline_scan (pipeline_stage_fn , length , use_scan = True ):
324343 """Wraps the pipeline stage execution in a `flax.linen.scan`.
325344
326345 This lifts the pipeline stage function so it can be repeated sequentially over
@@ -332,10 +351,12 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
332351 pipeline_stage_fn: The function representing a single pipeline stage
333352 (usually created by `create_rematerialized_pipeline_stage`).
334353 length: The total number of pipeline stages/repeats to scan over.
354+ use_scan: If True, scan over the repeats (unroll=1); if False, fully unroll the loop (unroll=length).
335355
336356 Returns:
337357 A Flax scanned function that executes the full pipeline schedule.
338358 """
359+ unroll_length = 1 if use_scan else length
339360 return nn .scan (
340361 pipeline_stage_fn ,
341362 variable_axes = {
@@ -346,5 +367,5 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
346367 },
347368 split_rngs = {"random" : True },
348369 length = length ,
349- unroll = length ,
370+ unroll = unroll_length ,
350371 )
0 commit comments