@@ -1239,7 +1239,7 @@ def _apply_sharding_hint(weight, pspec):
12391239 return _from_repeat_weights_to_bsw_shardmap (repeat_weights , physical_partition_spec , axes_to_gather = axes_to_gather )
12401240 return _from_repeat_weights_to_bsw_hint (repeat_weights )
12411241
1242- def weight_prefetching (self , weights , physical_partition_spec , loop_iteration ):
1242+ def both_weight_prefetching (self , weights , physical_partition_spec , loop_iteration ):
12431243 """Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
12441244
12451245 By gathering weights for `loop_iteration + 1` right now, the network communication
@@ -1250,7 +1250,16 @@ def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
12501250 nxt_repeat_weights = self .from_all_variables_to_repeat_weights (weights , loop_iteration + 1 )
12511251 bsw_0 = self .from_repeat_weights_to_bsw (cur_repeat_weights , physical_partition_spec )
12521252 bsw_1 = self .from_repeat_weights_to_bsw (nxt_repeat_weights , physical_partition_spec )
1253- return jax .ad_checkpoint .checkpoint_name ((bsw_0 , bsw_1 ), "bsw" )
1253+ return bsw_0 , bsw_1
1254+
1255+ def one_weight_prefetching (self , weights , physical_partition_spec , loop_iteration ):
1256+ """Triggers an asynchronous FSDP-like all-gather for the next pipeline step only.
1257+
1258+ By gathering weights for `loop_iteration + 1` right now, the network communication
1259+ can overlap with the compute happening in `loop_iteration`.
1260+ """
1261+ repeat_weights = self .from_all_variables_to_repeat_weights (weights , loop_iteration + 1 )
1262+ return self .from_repeat_weights_to_bsw (repeat_weights , physical_partition_spec )
12541263
12551264 def run_one_iteration (self , loop_state , bsw , positions , segment_ids , deterministic , model_mode , logical_partition_spec ):
12561265 """Executes the forward/backward logic for a single microbatch inside the pipeline.
@@ -1389,18 +1398,12 @@ def run_iteration_scannable(model, loop_state, bsw):
13891398 run_one_repeat_scannable = base_scannable (length = self .config .num_pipeline_microbatches )
13901399 run_bubbles_scannable = base_scannable (length = bubble_iterations )
13911400
1392- if self .config .scan_pipeline_repeats :
1393- run_repeats_scanned = pipeline_utils .create_flax_pipeline_scan (
1394- pipeline_stage_fn = run_one_repeat_scannable , length = self .config .num_pipeline_repeats
1395- )
1396- run_bubbles_scanned = pipeline_utils .create_flax_pipeline_scan (pipeline_stage_fn = run_bubbles_scannable , length = 1 )
1397- (loop_state , bsw ), _ = run_repeats_scanned (self , (loop_state , bsw ))
1398- (loop_state , bsw ), _ = run_bubbles_scanned (self , (loop_state , bsw ))
1399- else :
1400- for _ in range (self .config .num_pipeline_repeats ):
1401- (loop_state , bsw ), _ = run_one_repeat_scannable (self , (loop_state , bsw ))
1402- for _ in range (bubble_iterations ):
1403- (loop_state , bsw ), _ = run_iteration_scannable (self , loop_state , bsw )
1401+ run_repeats_scanned = pipeline_utils .create_flax_pipeline_scan (
1402+ pipeline_stage_fn = run_one_repeat_scannable , length = self .config .num_pipeline_repeats
1403+ )
1404+ run_bubbles_scanned = pipeline_utils .create_flax_pipeline_scan (pipeline_stage_fn = run_bubbles_scannable , length = 1 )
1405+ (loop_state , w_curr ), _ = run_repeats_scanned (self , (loop_state , bsw [0 ]))
1406+ (loop_state , _ ), _ = run_bubbles_scanned (self , (loop_state , w_curr ))
14041407
14051408 final_output = self .realign_output_microbatches (loop_state ["state_io" ])
14061409 final_output = jnp .reshape (
0 commit comments