@@ -590,8 +590,9 @@ def _all_gather_invariant(x, i):
590590 return x
591591
592592 new_variables = jax .tree .map (_all_gather_invariant , sharded_weights , fsdp_idx )
593+ new_variables = jax .ad_checkpoint .checkpoint_name (new_variables , "bsw_gathered_weights" )
593594
594- return ( cur_bsw [1 ], new_variables )
595+ return jax . ad_checkpoint . checkpoint_name (( cur_bsw [1 ], new_variables ), "bsw" )
595596
596597 return _all_gather_inner (repeat_weights , bsw , fsdp_idx )
597598
@@ -984,7 +985,7 @@ def run_iteration_scannable(model, loop_state):
984985 if self .config .set_remat_policy_on_pipeline_iterations :
985986 run_iteration_scannable = nn .remat (
986987 run_iteration_scannable ,
987- prevent_cse = not self .config .scan_pipeline_iterations ,
988+ prevent_cse = True , # not self.config.scan_pipeline_iterations,
988989 policy = self .get_pipeline_remat_policy (),
989990 )
990991
@@ -1017,7 +1018,7 @@ def run_one_repeat_scannable(model, loop_state):
10171018
10181019 run_one_repeat_scannable = nn .remat (
10191020 run_one_repeat_scannable ,
1020- prevent_cse = not self . config . scan_pipeline_iterations ,
1021+ prevent_cse = True ,
10211022 policy = self .get_pipeline_remat_policy (),
10221023 )
10231024
0 commit comments