@@ -194,56 +194,54 @@ def create_gradient_accumulation_scan(
194194 """
195195
196196 @functools .partial (jax .custom_vjp )
197- def run_single_microbatch_custom (lightweight_state , bsw , weights , pos_arg , seg_arg ):
198- return run_single_microbatch_custom_fwd (lightweight_state , bsw , weights , pos_arg , seg_arg )[0 ]
197+ def run_single_microbatch_custom (lightweight_state , bsw , pos_arg , seg_arg ):
198+ return run_single_microbatch_custom_fwd (lightweight_state , bsw , pos_arg , seg_arg )[0 ]
199199
200- def run_single_microbatch_custom_fwd (lightweight_state , bsw , weights , pos_arg , seg_arg ):
201- def _run (l , b , w ):
200+ def run_single_microbatch_custom_fwd (lightweight_state , bsw , pos_arg , seg_arg ):
201+ def _run (l , b ):
202202 out = model .run_one_iteration (
203- l , b , w , pos_arg , seg_arg , deterministic , model_mode , logical_partition_spec = logical_partition_spec
203+ l , b , pos_arg , seg_arg , deterministic , model_mode , logical_partition_spec = logical_partition_spec
204204 )
205- return out , b , w
205+ return out , b
206206
207207 # Rematerialize the inner step to save activation memory
208208 _run_remat = jax .remat (_run , prevent_cse = False , policy = model .get_pipeline_remat_policy ())
209- out , vjp_fun = jax .vjp (_run_remat , lightweight_state , bsw , weights )
209+ out , vjp_fun = jax .vjp (_run_remat , lightweight_state , bsw )
210210 return out , vjp_fun
211211
212212 def run_single_microbatch_custom_bwd (res , g_out ):
213213 vjp_fun = res
214- d_l , d_b , d_w = vjp_fun (g_out )
215- return d_l , d_b , d_w , None , None
214+ d_l , d_b = vjp_fun (g_out )
215+ return d_l , d_b , None , None
216216
217217 run_single_microbatch_custom .defvjp (run_single_microbatch_custom_fwd , run_single_microbatch_custom_bwd )
218218
219219 @functools .partial (jax .custom_vjp )
220- def run_pipeline_microbatches_custom (loop_state , bsw , weights , positions , segment_ids ):
221- return run_pipeline_microbatches_custom_fwd (loop_state , bsw , weights , positions , segment_ids )[0 ]
220+ def run_pipeline_microbatches_custom (loop_state , bsw , positions , segment_ids ):
221+ return run_pipeline_microbatches_custom_fwd (loop_state , bsw , positions , segment_ids )[0 ]
222222
223- def run_pipeline_microbatches_custom_fwd (loop_state , bsw , weights , positions , segment_ids ):
223+ def run_pipeline_microbatches_custom_fwd (loop_state , bsw , positions , segment_ids ):
224224 final_lightweight , scan_vjp_fun = jax .vjp (
225- lambda l , b , w : jax .lax .scan (
226- lambda carry , _ : (run_single_microbatch_custom (carry , b , w , positions , segment_ids )[0 ], None ),
225+ lambda l , b : jax .lax .scan (
226+ lambda carry , _ : (run_single_microbatch_custom (carry , b , positions , segment_ids )[0 ], None ),
227227 l ,
228228 None ,
229229 length = length ,
230230 )[0 ],
231231 loop_state ,
232232 bsw ,
233- weights ,
234233 )
235234
236- return (final_lightweight , bsw , weights ), scan_vjp_fun
235+ return (final_lightweight , bsw ), scan_vjp_fun
237236
238237 def run_pipeline_microbatches_custom_bwd (residuals , g_final_state ):
239238 scan_vjp_fun = residuals
240- g_lightweight , g_bsw , g_weights = g_final_state
241- d_init_lightweight , d_init_bsw , d_init_weights = scan_vjp_fun (g_lightweight )
239+ g_lightweight , g_bsw = g_final_state
240+ d_init_lightweight , d_init_bsw = scan_vjp_fun (g_lightweight )
242241
243242 d_init_bsw = jax .tree .map (lambda d , g : d + g if hasattr (d , "shape" ) else d , d_init_bsw , g_bsw )
244- d_init_weights = jax .tree .map (lambda d , g : d + g if hasattr (d , "shape" ) else d , d_init_weights , g_weights )
245243
246- return (d_init_lightweight , d_init_bsw , d_init_weights , None , None )
244+ return (d_init_lightweight , d_init_bsw , None , None )
247245
248246 run_pipeline_microbatches_custom .defvjp (run_pipeline_microbatches_custom_fwd , run_pipeline_microbatches_custom_bwd )
249247 return run_pipeline_microbatches_custom
@@ -259,6 +257,7 @@ def create_rematerialized_pipeline_stage(
259257 physical_partition_spec ,
260258 positions ,
261259 segment_ids ,
260+ pipeline_weights ,
262261):
263262 """Builds a memory-checkpointed execution block for a single pipeline stage.
264263
@@ -279,16 +278,17 @@ def create_rematerialized_pipeline_stage(
279278 physical_partition_spec: Rules for physical device mesh mappings (used in prefetching).
280279 positions: Position IDs for the sequence.
281280 segment_ids: Segment/Attention routing IDs for the sequence.
281+ pipeline_weights: The fully gathered pipeline weights, captured by the stage function via closure instead of being carried in its input tuple.
282282
283283 Returns:
284284 A function wrapped with `nn.remat` that takes `(model, (loop_state, bsw))` and returns
285285 `((loop_state, bsw), None)`, where `loop_state` is the updated loop state.
286286 """
287287
288- def execute_pipeline_stage (model , loop_state_and_bsw_and_weights ):
289- loop_state , bsw , weights = loop_state_and_bsw_and_weights
288+ def execute_pipeline_stage (model , loop_state_and_bsw ):
289+ loop_state , bsw = loop_state_and_bsw
290290 # Retrieve the specific weights needed for this pipeline chunk
291- bsw = model .weight_prefetching (weights , physical_partition_spec , loop_state ["loop_iteration" ])
291+ bsw = model .weight_prefetching (pipeline_weights , physical_partition_spec , loop_state ["loop_iteration" ])
292292
293293 if model .config .scan_pipeline_iterations :
294294 scan_microbatches_fn = create_gradient_accumulation_scan (
@@ -298,11 +298,11 @@ def execute_pipeline_stage(model, loop_state_and_bsw_and_weights):
298298 model_mode = model_mode ,
299299 logical_partition_spec = logical_partition_spec ,
300300 )
301- loop_state , bsw , weights = scan_microbatches_fn (loop_state , bsw , weights , positions , segment_ids )
301+ loop_state , bsw = scan_microbatches_fn (loop_state , bsw , positions , segment_ids )
302302 else :
303303 for _ in range (length ):
304- (loop_state , bsw , weights ), _ = run_iteration_scannable (model , loop_state , bsw , weights )
305- return (loop_state , bsw , weights ), None
304+ (loop_state , bsw ), _ = run_iteration_scannable (model , loop_state , bsw )
305+ return (loop_state , bsw ), None
306306
307307 return nn .remat (
308308 execute_pipeline_stage ,
0 commit comments