@@ -1173,7 +1173,7 @@ def from_repeat_weights_to_bsw(
11731173 self ,
11741174 repeat_weights ,
11751175 physical_partition_spec ,
1176- axes_to_gather = ("fsdp" , "fsdp_transpose" , "expert" ), # three major FSDP-like axes
1176+ axes_to_gather = ("fsdp" , "fsdp_transpose" , "context" , "expert" ), # four major FSDP-like axes
11771177 use_shardmap = False , # using shardmap produces additional reduce-scatter in backward pass
11781178 ):
11791179 """Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW."""
@@ -1244,20 +1244,7 @@ def _apply_sharding_hint(weight, pspec):
12441244 return _from_repeat_weights_to_bsw_shardmap (repeat_weights , physical_partition_spec , axes_to_gather = axes_to_gather )
12451245 return _from_repeat_weights_to_bsw_hint (repeat_weights )
12461246
1247- def both_weight_prefetching (self , weights , physical_partition_spec , loop_iteration ):
1248- """Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
1249-
1250- By gathering weights for `loop_iteration + 1` right now, the network communication
1251- can overlap with the compute happening in `loop_iteration`. The dual-buffers
1252- are returned grouped in an explicit `jax.ad_checkpoint` to strictly control memory.
1253- """
1254- cur_repeat_weights = self .from_all_variables_to_repeat_weights (weights , loop_iteration )
1255- nxt_repeat_weights = self .from_all_variables_to_repeat_weights (weights , loop_iteration + 1 )
1256- bsw_0 = self .from_repeat_weights_to_bsw (cur_repeat_weights , physical_partition_spec )
1257- bsw_1 = self .from_repeat_weights_to_bsw (nxt_repeat_weights , physical_partition_spec )
1258- return bsw_0 , bsw_1
1259-
1260- def one_weight_prefetching (self , weights , physical_partition_spec , loop_iteration ):
1247+ def weight_prefetching (self , weights , physical_partition_spec , loop_iteration ):
12611248 """Triggers asynchronous FSDP-like all-gathers for the next pipeline steps.
12621249
12631250 By gathering weights for `loop_iteration + 1` right now, the network communication
@@ -1351,7 +1338,6 @@ def __call__(
13511338 segment_idx = None
13521339
13531340 loop_state , bsw = self .init_states (inputs )
1354- weights = self .layers .variables
13551341 physical_partition_spec = logical_to_mesh (
13561342 logical_partition_spec , mesh = self .mesh , rules = self .config .logical_axis_rules
13571343 )
@@ -1388,41 +1374,34 @@ def run_iteration_scannable(model, loop_state, bsw):
13881374
13891375 # base scannable function used twice for real and bubble runs
13901376 base_scannable = functools .partial (
1391- pipeline_utils .create_rematerialized_pipeline_stage ,
1377+ pipeline_utils .create_pipeline_stage ,
13921378 deterministic = deterministic ,
13931379 model_mode = model_mode ,
13941380 logical_partition_spec = logical_partition_spec ,
13951381 physical_partition_spec = physical_partition_spec ,
13961382 positions = positions ,
13971383 segment_ids = segment_ids ,
1398- pipeline_weights = weights ,
13991384 )
14001385
14011386 run_one_repeat_scannable = base_scannable (length = self .config .num_pipeline_microbatches )
1402- # run_one_repeat_scannable = nn.remat(
1403- # run_one_repeat_scannable,
1404- # prevent_cse=True,
1405- # policy=self.get_pipeline_remat_policy()
1406- # )
14071387 run_bubbles_scannable = base_scannable (length = bubble_iterations )
1408- # run_bubbles_scannable = nn.remat(
1409- # run_bubbles_scannable,
1410- # prevent_cse=True,
1411- # policy=self.get_pipeline_remat_policy()
1412- # )
14131388
14141389 run_repeats_scanned = pipeline_utils .create_flax_pipeline_scan (
14151390 pipeline_stage_fn = run_one_repeat_scannable ,
14161391 length = self .config .num_pipeline_repeats ,
1392+ remat_policy = self .get_pipeline_remat_policy (),
14171393 use_scan = self .config .scan_pipeline_repeats ,
14181394 )
14191395 run_bubbles_scanned = pipeline_utils .create_flax_pipeline_scan (
14201396 pipeline_stage_fn = run_bubbles_scannable ,
14211397 length = 1 ,
1398+ remat_policy = self .get_pipeline_remat_policy (),
14221399 use_scan = self .config .scan_pipeline_repeats ,
14231400 )
1424- (loop_state , w_curr ), _ = run_repeats_scanned (self , (loop_state , bsw [0 ]))
1425- (loop_state , _ ), _ = run_bubbles_scanned (self , (loop_state , w_curr ))
1401+ initial_carry_repeats = (loop_state , bsw [0 ], self .layers .variables )
1402+ (loop_state , w_curr , pipeline_weights ), _ = run_repeats_scanned (self , initial_carry_repeats )
1403+ initial_carry_bubbles = (loop_state , w_curr , pipeline_weights )
1404+ (loop_state , _ , pipeline_weights ), _ = run_bubbles_scanned (self , initial_carry_bubbles )
14261405
14271406 final_output = self .realign_output_microbatches (loop_state ["state_io" ])
14281407 final_output = jnp .reshape (
0 commit comments