@@ -184,6 +184,7 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
     non-circular"""
     # Stage 0 has processed one microbatch every loop iteration, but stage 1 is one behind due to the pipeline bubble, and so on for later stages
     microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
+    microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage")))
     microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids
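For intuition, the id arithmetic in this hunk can be exercised standalone; the configuration values below are invented purely for illustration:

```python
import jax.numpy as jnp

# Hypothetical configuration: 4 stages, 8 microbatches per repeat, forwarding delay of 1.
num_stages, num_pipeline_microbatches, forwarding_delay = 4, 8, 1
loop_iteration = 10

# Each later stage lags the previous one by `forwarding_delay` loop iterations.
microbatches_processed = jnp.maximum(loop_iteration - forwarding_delay * jnp.arange(num_stages), 0)
microbatch_ids = microbatches_processed % num_pipeline_microbatches   # index within the current repeat
repeat_ids = microbatches_processed // num_pipeline_microbatches      # which circular repeat each stage is on

print(microbatches_processed)  # [10  9  8  7]
print(microbatch_ids)          # [2 1 0 7]
print(repeat_ids)              # [1 1 1 0]
```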
@@ -1006,8 +1007,12 @@ def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim):
 
     def _gather_one(x, i):
       idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
-      replicated_sharding = NamedSharding(self.mesh, P())
-      return x.at[idx].get(out_sharding=replicated_sharding)
+      positions_sharding = (
+          create_sharding(self.mesh, (None, "layers", "activation_length"))
+          if self.config.shard_mode == ShardMode.EXPLICIT
+          else None
+      )
+      return x.at[idx].get(out_sharding=positions_sharding)
 
     return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
 
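The vmap'd gather above extracts one slice per stage from a stacked buffer and restacks the results along `ids_dim`; the hunk changes the `out_sharding` of each gathered slice from full replication (`P()`) to a `(None, "layers", "activation_length")` layout when explicit sharding is on. A minimal sketch of the same indexing pattern, with sharding stripped out and invented shapes:

```python
import jax
import jax.numpy as jnp

ids_dim = 0  # the axis that the per-stage ids index into
xs = jnp.arange(6 * 4 * 5).reshape(6, 4, 5)  # e.g. (num_microbatches, layers, length)
ids = jnp.array([3, 1, 0, 2])                # one microbatch id per stage

def _gather_one(x, i):
  # Build an index tuple that selects `i` along `ids_dim` and keeps every other axis whole.
  idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
  return x[idx]

# One gathered slice per stage, stacked back along `ids_dim`.
gathered = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
print(gathered.shape)  # (4, 4, 5): (num_stages, layers, length)
```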
@@ -1229,7 +1234,7 @@ def _apply_sharding_hint(weight, pspec):
         weight,
         sharding_name,
         shard_mode=self.config.shard_mode,
-        debug_sharding=self.config.shard_mode,
+        debug_sharding=self.config.debug_sharding,
         extra_stack_level=0,
     )
 
@@ -1239,7 +1244,7 @@ def _apply_sharding_hint(weight, pspec):
       return _from_repeat_weights_to_bsw_shardmap(repeat_weights, physical_partition_spec, axes_to_gather=axes_to_gather)
     return _from_repeat_weights_to_bsw_hint(repeat_weights)
 
-  def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
+  def both_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
     """Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
 
     By gathering weights for `loop_iteration + 1` right now, the network communication
@@ -1250,7 +1255,16 @@ def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
     nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
     bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec)
     bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
-    return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
+    return bsw_0, bsw_1
+
+  def one_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
+    """Triggers an asynchronous FSDP-like all-gather for the next pipeline step.
+
+    By gathering weights for `loop_iteration + 1` right now, the network communication
+    can overlap with the compute happening in `loop_iteration`.
+    """
+    repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
+    return self.from_repeat_weights_to_bsw(repeat_weights, physical_partition_spec)
 
   def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec):
     """Executes the forward/backward logic for a single microbatch inside the pipeline.
@@ -1389,18 +1403,12 @@ def run_iteration_scannable(model, loop_state, bsw):
     run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
     run_bubbles_scannable = base_scannable(length=bubble_iterations)
 
-    if self.config.scan_pipeline_repeats:
-      run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
-          pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
-      )
-      run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
-      (loop_state, bsw), _ = run_repeats_scanned(self, (loop_state, bsw))
-      (loop_state, bsw), _ = run_bubbles_scanned(self, (loop_state, bsw))
-    else:
-      for _ in range(self.config.num_pipeline_repeats):
-        (loop_state, bsw), _ = run_one_repeat_scannable(self, (loop_state, bsw))
-      for _ in range(bubble_iterations):
-        (loop_state, bsw), _ = run_iteration_scannable(self, loop_state, bsw)
+    run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
+        pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
+    )
+    run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
+    (loop_state, w_curr), _ = run_repeats_scanned(self, (loop_state, bsw[0]))
+    (loop_state, _), _ = run_bubbles_scanned(self, (loop_state, w_curr))
 
     final_output = self.realign_output_microbatches(loop_state["state_io"])
     final_output = jnp.reshape(
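The rewritten loop makes the weight carry handoff explicit: the repeats scan is seeded with `bsw[0]`, and the weights it carries out (`w_curr`) feed the bubble scan, which presumably runs the extra iterations needed to drain in-flight microbatches from the later stages (roughly `forwarding_delay * (num_stages - 1)` of them, following the stage lag computed earlier). A schematic of the two-scan handoff with plain `jax.lax.scan`; `step` is a hypothetical stand-in for one pipeline iteration:

```python
import jax
import jax.numpy as jnp

def step(carry, _):
  loop_state, w = carry
  # Hypothetical single pipeline iteration: consume the carried weights.
  return (loop_state + jnp.sum(w), w), None

loop_state = jnp.float32(0.0)
w0 = jnp.ones(4)
num_repeats, bubble_iterations = 2, 3

# Run the full repeats first, threading the weights through the carry...
(loop_state, w_curr), _ = jax.lax.scan(step, (loop_state, w0), None, length=num_repeats)
# ...then drain the trailing bubble, reusing the weights carried out of the repeats scan.
(loop_state, _), _ = jax.lax.scan(step, (loop_state, w_curr), None, length=bubble_iterations)
print(loop_state)  # 20.0
```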