@@ -248,24 +248,21 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
248248 return run_pipeline_microbatches_custom
249249
250250
def create_pipeline_stage(
    length,
    deterministic,
    model_mode,
    logical_partition_spec,
    physical_partition_spec,
    positions,
    segment_ids,
):
  """Builds an execution block for a single pipeline stage.

  This function prepares the state for a specific chunk of pipeline execution by:
  1. Prefetching the required weights (e.g., FSDP-gathered) for the current stage/loop iteration.
  2. Executing `length` microbatches using a memory-efficient `jax.lax.scan` via a custom VJP
     that manages collective communication overlap.

  Args:
    length: The number of microbatches to process in this stage.
    deterministic: Whether to run the model deterministically (e.g., with dropout disabled) —
      NOTE(review): this Args entry falls between diff hunks; confirm wording against the full file.
    model_mode: The model execution mode — NOTE(review): inter-hunk gap; confirm against the full file.
    logical_partition_spec: Rules for logical axis sharding — NOTE(review): inter-hunk gap; confirm.
    physical_partition_spec: Rules for physical device mesh mappings (used in prefetching).
    positions: Position IDs for the sequence.
    segment_ids: Segment/Attention routing IDs for the sequence.

  Returns:
    A function that takes `(model, carry)` and returns the updated `carry` and `None` for the
    scan outputs.
  """

  def execute_pipeline_stage_flax(model, carry):
    """A non-pure Flax closure of the pipeline stage.

    This function bridges the pure JAX custom VJP logic with Flax's object-oriented
    lifting mechanisms. It unpacks the carry state and routes it through the pure VJP function.

    Args:
      model: CircularPipeline Flax linen model instance.
      carry: A tuple containing (loop_state, w_curr, pipeline_weights).
        - loop_state: The current execution state of the pipeline.
        - w_curr: The gathered weights used for the current pipeline step.
        - pipeline_weights: The fully sharded baseline weights.

    Returns:
      A `(carry, None)` pair compatible with `nn.scan`.
    """

    loop_state, w_curr, pipeline_weights = carry

    scan_microbatches_fn = create_gradient_accumulation_scan(
        model=model,
        # NOTE(review): the three keyword arguments below fall in an inter-hunk context gap of
        # the reviewed diff and were reconstructed from this factory's parameters — confirm
        # against the full file.
        length=length,
        deterministic=deterministic,
        model_mode=model_mode,
        logical_partition_spec=logical_partition_spec,
    )

    # Establish a pure function boundary to allow for custom VJP definition.
    @jax.custom_vjp
    def execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights):
      return execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights)[0]

    def execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights):
      # Prefetch FSDP-sharded weights for the upcoming pipeline repeat.
      w_next = model.weight_prefetching(
          pipeline_weights,
          physical_partition_spec,
          loop_state["loop_iteration"],
      )
      # Construct a buffered sliding window (BSW) of weights.
      #   w_curr: Weights actively used for the current microbatch steps.
      #   w_next: Newly gathered weights that will be carried forward as the new w_curr.
      bsw = (w_curr, w_next)
      # Bind arguments to the weight prefetching function to prepare it for linear transpose.
      p_weight_prefetching = functools.partial(
          model.weight_prefetching,
          physical_partition_spec=physical_partition_spec,
          loop_iteration=loop_state["loop_iteration"],
      )
      # Since weight gathering (all-gather) is a linear operation, we can derive its dual
      # (reduce-scatter) via jax.linear_transpose. This avoids redundant forward passes.
      weight_prefetching_t = jax.linear_transpose(
          p_weight_prefetching,
          pipeline_weights,
      )
      # Execute the forward pass of the microbatches and generate its VJP.
      # The VJP captures necessary checkpoints to evaluate gradients later.
      (loop_state, bsw), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
      # Discard the old weights (w_curr) and advance w_next to act as the current weights in
      # the next iteration.
      _, w_next = bsw
      return (loop_state, w_next), (scan_microbatches_vjp, weight_prefetching_t)

    def execute_pipeline_stage_pure_bwd(residuals, g_outputs):
      # Unpack forward pass residuals (VJP closures) and the incoming output gradients.
      g_loop_state, g_w_next = g_outputs
      scan_microbatches_vjp, weight_prefetching_t = residuals
      # Initialize zero cotangents for w_curr, as it was consumed in the forward pass.
      g_w_curr = jax.tree.map(jnp.zeros_like, g_w_next)
      g_bsw = (g_w_curr, g_w_next)
      # Backpropagate gradients through the dual microbatch execution block.
      g_loop_state, g_bsw, _, _ = scan_microbatches_vjp((g_loop_state, g_bsw))
      # Apply the linear transpose of the weight prefetch to execute the reduce-scatter.
      # This maps the gradients of the gathered weights back to the FSDP-sharded parameter space.
      g_w_curr, g_w_next = g_bsw
      (g_pipeline_weights,) = weight_prefetching_t(g_w_next)
      # Return gradients corresponding to the three original inputs of execute_pipeline_stage_pure.
      return g_loop_state, g_w_curr, g_pipeline_weights

    execute_pipeline_stage_pure.defvjp(execute_pipeline_stage_pure_fwd, execute_pipeline_stage_pure_bwd)
    # Execute the pure pipeline stage. We unpack the two modified outputs (loop_state, w_next)
    # and repack them alongside the unmodified pipeline_weights to maintain a consistent carry
    # shape for nn.scan.
    return (*execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights), pipeline_weights), None

  return execute_pipeline_stage_flax
338362
339363
340- def create_flax_pipeline_scan (pipeline_stage_fn , length , use_scan = True ):
341- """Wraps the pipeline stage execution in a `flax.linen.scan`.
364+ def create_flax_pipeline_scan (pipeline_stage_fn , length , remat_policy , use_scan = True ):
365+ """Wraps the pipeline stage execution in `flax.linen.remat` and `flax.linen.scan`.
342366
343- This lifts the pipeline stage function so it can be repeated sequentially over
344- the specified length. It safely handles Flax-specific state collections, ensuring
345- that metrics, intermediate values, and PRNG keys do not collide or overwrite
346- each other across the loop iterations.
367+ This explicitly wraps the pipeline step in a gradient checkpointing policy
368+ and then lifts it so it can be repeated sequentially over the specified length.
369+ It safely handles Flax-specific state collections, ensuring that metrics, intermediate
370+ values, and PRNG keys do not collide or overwrite each other across loop iterations.
347371
348372 Args:
349373 pipeline_stage_fn: The function representing a single pipeline stage
350- (usually created by `create_rematerialized_pipeline_stage`).
374+ (usually created by `create_pipeline_stage`).
351375 length: The total number of pipeline stages/repeats to scan over.
376+ remat_policy: The checkpointing policy used by `nn.remat` to manage activation memory.
352- use_scan: Either scan over repeats or unroll the scan .
377+ use_scan: Whether to use `jax.lax.scan` (True) or unroll the loop (False) .
353378
354379 Returns:
355380 A Flax scanned function that executes the full pipeline schedule.
356381 """
357382 unroll_length = 1 if use_scan else length
358383 return nn .scan (
359- pipeline_stage_fn ,
384+ nn .remat (
385+ pipeline_stage_fn ,
386+ policy = remat_policy ,
387+ ),
360388 variable_axes = {
361389 "summaries" : 0 ,
362390 "aux_loss" : 0 ,
0 commit comments