@@ -184,6 +184,7 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
     non-circular"""
     # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages
     microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
+    microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage")))
     microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids
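
For intuition, a small worked example of the id arithmetic in this hunk, showing the per-stage vector that now carries the `P("stage")` hint (toy values, not from the codebase):

```python
import jax.numpy as jnp

num_stages, num_microbatches, forwarding_delay = 4, 3, 1
loop_iteration = 5

# Each stage lags the previous one by `forwarding_delay` iterations (the bubble).
processed = jnp.maximum(loop_iteration - forwarding_delay * jnp.arange(num_stages), 0)
# processed == [5, 4, 3, 2]; this per-stage vector is what the new
# NamedSharding(mesh, P("stage")) hint applies to.
microbatch_ids = processed % num_microbatches   # [2, 1, 0, 2]
repeat_ids = processed // num_microbatches      # [1, 1, 1, 0]
```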
@@ -1006,8 +1007,12 @@ def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim):

     def _gather_one(x, i):
       idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
-      replicated_sharding = NamedSharding(self.mesh, P())
-      return x.at[idx].get(out_sharding=replicated_sharding)
+      positions_sharding = (
+          create_sharding(self.mesh, (None, "layers", "activation_length"))
+          if self.config.shard_mode == ShardMode.EXPLICIT
+          else None
+      )
+      return x.at[idx].get(out_sharding=positions_sharding)

     return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)

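The surrounding `vmap` gathers one microbatch slice per index along `ids_dim`; a minimal runnable sketch of that indexing pattern with toy shapes, dropping the sharding hint that this hunk changes:

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(2 * 5 * 3).reshape(2, 5, 3)  # toy (prefix, microbatch, feature) buffer
ids = jnp.array([4, 0])                      # one gather index per vmapped call
ids_dim = 1

def _gather_one(x, i):
  # Index dynamically along ids_dim, keep every other dimension whole.
  idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
  return x[idx]

out = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
# out.shape == (2, 2, 3): axis ids_dim now enumerates the two gathered slices
```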
@@ -1168,11 +1173,12 @@ def from_repeat_weights_to_bsw(
       self,
       repeat_weights,
       physical_partition_spec,
-      axes_to_gather=("fsdp", "fsdp_transpose", "expert"),  # three major FSDP-like axes
+      axes_to_gather=("fsdp", "fsdp_transpose", "context", "expert"),
+      # TODO(chengnuojin): set use_shardmap=True after JAX >= 10.0.0 and use all_gather(..., to='invarying')
       use_shardmap=False,  # using shardmap produces additional reduce-scatter in backward pass
   ):
     """Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW."""
-    axes_to_remove = ["fsdp", "fsdp_transpose"]
+    axes_to_remove = ["fsdp", "fsdp_transpose", "context"]
     bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove)

     def _from_repeat_weights_to_bsw_shardmap(
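
`derive_stage_weight_partition_specs` itself is not shown in this diff; the intent, as the new `axes_to_remove` suggests, is to drop the gathered mesh axes from each weight's spec. A hypothetical sketch of that stripping (helper name and structure are illustrative, not the real implementation):

```python
from jax.sharding import PartitionSpec as P

def strip_axes(pspec, axes_to_remove):
  """Hypothetical: drop named mesh axes from a PartitionSpec.

  Once a weight block has been all-gathered over "fsdp", "fsdp_transpose",
  and (with this change) "context", those names must not appear in the spec
  describing the gathered block.
  """
  def strip_entry(entry):
    if isinstance(entry, (tuple, list)):  # one dim may map to several mesh axes
      kept = tuple(a for a in entry if a not in axes_to_remove)
      return kept if kept else None
    return None if entry in axes_to_remove else entry

  return P(*(strip_entry(e) for e in pspec))

# P(('fsdp', 'context'), 'tensor') -> PartitionSpec(None, 'tensor')
print(strip_axes(P(("fsdp", "context"), "tensor"), {"fsdp", "fsdp_transpose", "context"}))
```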
@@ -1229,7 +1235,7 @@ def _apply_sharding_hint(weight, pspec):
         weight,
         sharding_name,
         shard_mode=self.config.shard_mode,
-        debug_sharding=self.config.shard_mode,
+        debug_sharding=self.config.debug_sharding,
         extra_stack_level=0,
     )

@@ -1240,21 +1246,15 @@ def _apply_sharding_hint(weight, pspec):
     return _from_repeat_weights_to_bsw_hint(repeat_weights)

   def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
-    """Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
+    """Triggers asynchronous FSDP-like all-gathers for the next pipeline step.

     By gathering weights for `loop_iteration + 1` right now, the network communication
-    can overlap with the compute happening in `loop_iteration`. The dual-buffers
-    are returned grouped in an explicit `jax.ad_checkpoint` to strictly control memory.
+    can overlap with the compute happening in `loop_iteration`.
     """
-    cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration)
-    nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
-    bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec)
-    bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
-    return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
+    repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
+    return self.from_repeat_weights_to_bsw(repeat_weights, physical_partition_spec)

-  def run_one_iteration(
-      self, loop_state, bsw, weights, positions, segment_ids, deterministic, model_mode, logical_partition_spec
-  ):
+  def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec):
     """Executes the forward/backward logic for a single microbatch inside the pipeline.

     This acts as the core step function that our `jax.lax.scan` wrappers call. It routes
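
With the dual buffer removed from `weight_prefetching`, the overlap now comes from the scan carry (see the `(loop_state, w_curr, pipeline_weights)` carry in the last hunk): each step computes with the block gathered on the previous step while prefetching the next one. A self-contained toy of that pattern (all names and the fake "gather" are illustrative):

```python
import jax
import jax.numpy as jnp

all_weights = jnp.arange(12.0).reshape(4, 3)  # four toy weight blocks

def prefetch(i):
  """Toy stand-in for weight_prefetching: 'gather' block i + 1."""
  return all_weights[(i + 1) % all_weights.shape[0]]

def step(carry, _):
  i, w_curr, acc = carry
  w_next = prefetch(i)           # communication for step i + 1 ...
  acc = acc + jnp.sum(w_curr)    # ... overlaps the compute for step i
  return (i + 1, w_next, acc), None

init = (jnp.int32(0), all_weights[0], jnp.float32(0.0))  # first block gathered up front
(_, _, total), _ = jax.lax.scan(step, init, None, length=4)
# total == 66.0, i.e. every block was consumed exactly once
```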
@@ -1339,7 +1339,6 @@ def __call__(
       segment_idx = None

     loop_state, bsw = self.init_states(inputs)
-    weights = self.layers.variables
     physical_partition_spec = logical_to_mesh(
         logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
     )
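
`logical_to_mesh` here translates logical axis names into a physical spec via `config.logical_axis_rules`. For readers unfamiliar with that translation, a sketch using Flax's analogous `nn.logical_to_mesh_axes` (the rules below are made up for illustration):

```python
import flax.linen as nn

# Hypothetical rules in the spirit of config.logical_axis_rules:
rules = (("embed", "fsdp"), ("mlp", "tensor"), ("layers", None))
spec = nn.logical_to_mesh_axes(("layers", "embed", "mlp"), rules)
# spec == PartitionSpec(None, 'fsdp', 'tensor')
```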
@@ -1353,12 +1352,11 @@ def __call__(

     logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec)

-    def run_iteration_scannable(model, loop_state, bsw, weights):
+    def run_iteration_scannable(model, loop_state, bsw):
       return (
           model.run_one_iteration(
               loop_state,
               bsw,
-              weights,
               positions,
               segment_ids,
               deterministic,
@@ -1377,9 +1375,7 @@ def run_iteration_scannable(model, loop_state, bsw, weights):

     # base scannable function used twice for real and bubble runs
     base_scannable = functools.partial(
-        pipeline_utils.create_rematerialized_pipeline_stage,
-        model=self,
-        run_iteration_scannable=run_iteration_scannable,
+        pipeline_utils.create_pipeline_stage,
         deterministic=deterministic,
         model_mode=model_mode,
         logical_partition_spec=logical_partition_spec,
@@ -1391,18 +1387,22 @@ def run_iteration_scannable(model, loop_state, bsw, weights):
     run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
     run_bubbles_scannable = base_scannable(length=bubble_iterations)

-    if self.config.scan_pipeline_repeats:
-      run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
-          pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
-      )
-      run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
-      (loop_state, bsw, weights), _ = run_repeats_scanned(self, (loop_state, bsw, weights))
-      (loop_state, bsw, weights), _ = run_bubbles_scanned(self, (loop_state, bsw, weights))
-    else:
-      for _ in range(self.config.num_pipeline_repeats):
-        (loop_state, bsw, weights), _ = run_one_repeat_scannable(self, (loop_state, bsw, weights))
-      for _ in range(bubble_iterations):
-        (loop_state, bsw, weights), _ = run_iteration_scannable(self, loop_state, bsw, weights)
+    run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
+        pipeline_stage_fn=run_one_repeat_scannable,
+        length=self.config.num_pipeline_repeats,
+        remat_policy=self.get_pipeline_remat_policy(),
+        use_scan=self.config.scan_pipeline_repeats,
+    )
+    run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(
+        pipeline_stage_fn=run_bubbles_scannable,
+        length=1,
+        remat_policy=self.get_pipeline_remat_policy(),
+        use_scan=self.config.scan_pipeline_repeats,
+    )
+    initial_carry_repeats = (loop_state, bsw[0], self.layers.variables)
+    (loop_state, w_curr, pipeline_weights), _ = run_repeats_scanned(self, initial_carry_repeats)
+    initial_carry_bubbles = (loop_state, w_curr, pipeline_weights)
+    (loop_state, _, pipeline_weights), _ = run_bubbles_scanned(self, initial_carry_bubbles)

     final_output = self.realign_output_microbatches(loop_state["state_io"])
     final_output = jnp.reshape(
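
`create_flax_pipeline_scan` now absorbs the scan-vs-unroll branch behind `use_scan` and takes the remat policy directly. A hypothetical sketch of that dispatch using plain `jax.lax.scan` (the real helper is Flax-aware and applies `remat_policy`, both omitted here):

```python
import jax

def create_scan_or_loop(step_fn, length, use_scan=True):
  """Hypothetical: one entry point for scanned vs. unrolled execution."""
  def run(carry):
    if use_scan:
      return jax.lax.scan(step_fn, carry, None, length=length)
    ys = None
    for _ in range(length):  # unrolled fallback: traced once per iteration
      carry, ys = step_fn(carry, None)
    return carry, ys
  return run

step = lambda c, _: (c + 1, None)
final_scan, _ = create_scan_or_loop(step, 3, use_scan=True)(0)
final_loop, _ = create_scan_or_loop(step, 3, use_scan=False)(0)
# final_scan == final_loop == 3; both modes expose the same interface
```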