1919from jax .sharding import PartitionSpec as P
2020from flax import linen as nn
2121from flax .linen .spmd import LogicallyPartitioned
22+ import jax .numpy as jnp
2223
2324
2425def get_mesh_axis_dim_indices (physical_partition_spec , axis_name = "fsdp" ):
@@ -248,8 +249,6 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
248249
249250
250251def create_rematerialized_pipeline_stage (
251- model ,
252- run_iteration_scannable ,
253252 length ,
254253 deterministic ,
255254 model_mode ,
@@ -269,8 +268,6 @@ def create_rematerialized_pipeline_stage(
269268 activations during the backward pass based on the model's policy.
270269
271270 Args:
272- model: The model instance containing configuration and prefetching logic.
273- run_iteration_scannable: A fallback function for executing a single iteration unrolled.
274271 length: The number of microbatches to process in this stage.
275272 deterministic: Whether to run deterministically (e.g., disable dropout).
276273 model_mode: The operational mode (e.g., 'train').
@@ -285,42 +282,62 @@ def create_rematerialized_pipeline_stage(
285282 the updated `loop_state`.
286283 """
287284
288- def execute_pipeline_stage (model , loop_state_and_bsw ):
289- loop_state , w_curr = loop_state_and_bsw
290- # # Retrieve the specific weights needed for this pipeline chunk
291- # bsw = model.both_weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292- w_next = jax .remat (
293- model .one_weight_prefetching ,
294- static_argnums = (1 ,),
295- policy = jax .checkpoint_policies .nothing_saveable ,
296- )(
297- pipeline_weights ,
298- physical_partition_spec ,
299- loop_state ["loop_iteration" ],
300- )
301- bsw = (w_curr , w_next )
285+ def execute_pipeline_stage_outer (model , loop_state_and_bsw ):
286+
302287 scan_microbatches_fn = create_gradient_accumulation_scan (
303288 model = model ,
304289 length = length ,
305290 deterministic = deterministic ,
306291 model_mode = model_mode ,
307292 logical_partition_spec = logical_partition_spec ,
308293 )
309- loop_state , bsw = scan_microbatches_fn (loop_state , bsw , positions , segment_ids )
310- w_curr , w_next = bsw
311- del w_curr
312- return (loop_state , w_next ), None
313294
314- return execute_pipeline_stage
295+ remat_weight_prefetching = model .one_weight_prefetching
296+
297+ @jax .custom_vjp
298+ def execute_pipeline_stage (loop_state_and_bsw , pipeline_weights ):
299+ return execute_pipeline_stage_custom_fwd (loop_state_and_bsw , pipeline_weights )[0 ]
300+
301+ def execute_pipeline_stage_custom_fwd (loop_state_and_bsw , pipeline_weights ):
302+ loop_state , w_curr = loop_state_and_bsw
303+ # # Retrieve the specific weights needed for this pipeline chunk
304+ w_next = remat_weight_prefetching (
305+ pipeline_weights ,
306+ physical_partition_spec ,
307+ loop_state ["loop_iteration" ],
308+ )
309+ bsw = (w_curr , w_next )
310+ p_remat_weight_prefetching = functools .partial (
311+ remat_weight_prefetching ,
312+ physical_partition_spec = physical_partition_spec ,
313+ loop_iteration = loop_state ["loop_iteration" ],
314+ )
315+ remat_weight_prefetching_t = jax .linear_transpose (
316+ p_remat_weight_prefetching ,
317+ pipeline_weights ,
318+ )
319+ (loop_state , bsw ), scan_fn_vjp = jax .vjp (scan_microbatches_fn , loop_state , bsw , positions , segment_ids )
320+ w_curr , w_next = bsw
321+ return (loop_state , w_next ), (scan_fn_vjp , remat_weight_prefetching_t )
322+
323+ def execute_pipeline_stage_custom_bwd (residuals , g_outputs ):
324+ g_loop_state , g_w_next = g_outputs
325+ scan_fn_vjp , remat_weight_prefetching_t = residuals
326+ g_w_curr = jax .tree .map (jnp .zeros_like , g_w_next )
327+ g_bsw = (g_w_curr , g_w_next )
328+ g_loop_state , g_bsw , _ , _ = scan_fn_vjp ((g_loop_state , g_bsw ))
329+ g_w_curr , g_w_next = g_bsw
330+ (g_pipeline_weights ,) = remat_weight_prefetching_t (g_w_next )
331+ return (g_loop_state , g_w_curr ), g_pipeline_weights
332+
333+ execute_pipeline_stage .defvjp (execute_pipeline_stage_custom_fwd , execute_pipeline_stage_custom_bwd )
334+
335+ return execute_pipeline_stage (loop_state_and_bsw , pipeline_weights ), None
315336
316- # return nn.remat(
317- # execute_pipeline_stage,
318- # prevent_cse=not model.config.scan_pipeline_iterations,
319- # policy=model.get_pipeline_remat_policy(),
320- # )
337+ return execute_pipeline_stage_outer
321338
322339
323- def create_flax_pipeline_scan (pipeline_stage_fn , length ):
340+ def create_flax_pipeline_scan (pipeline_stage_fn , length , use_scan = True ):
324341 """Wraps the pipeline stage execution in a `flax.linen.scan`.
325342
326343 This lifts the pipeline stage function so it can be repeated sequentially over
@@ -332,10 +349,12 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
332349 pipeline_stage_fn: The function representing a single pipeline stage
333350 (usually created by `create_rematerialized_pipeline_stage`).
334351 length: The total number of pipeline stages/repeats to scan over.
352+ use_scan: If True, scan sequentially over the repeats; if False, fully unroll the loop.
335353
336354 Returns:
337355 A Flax scanned function that executes the full pipeline schedule.
338356 """
357+ unroll_length = 1 if use_scan else length
339358 return nn .scan (
340359 pipeline_stage_fn ,
341360 variable_axes = {
@@ -344,7 +363,11 @@ def create_flax_pipeline_scan(pipeline_stage_fn, length):
344363 "intermediates" : 0 ,
345364 "hyper_params" : 0 ,
346365 },
366+ variable_broadcast = [
367+ "_overwrite_with_gradient" ,
368+ "non_trainable" ,
369+ ],
347370 split_rngs = {"random" : True },
348371 length = length ,
349- unroll = length ,
372+ unroll = unroll_length ,
350373 )
0 commit comments