Skip to content

Commit 215a736

Browse files
committed
add bsw checkpoint
1 parent 04ffb00, commit 215a736

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

src/MaxText/layers/pipeline.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -590,8 +590,9 @@ def _all_gather_invariant(x, i):
590590
return x
591591

592592
new_variables = jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
593+
new_variables = jax.ad_checkpoint.checkpoint_name(new_variables, "bsw_gathered_weights")
593594

594-
return (cur_bsw[1], new_variables)
595+
return jax.ad_checkpoint.checkpoint_name((cur_bsw[1], new_variables), "bsw")
595596

596597
return _all_gather_inner(repeat_weights, bsw, fsdp_idx)
597598

@@ -984,7 +985,7 @@ def run_iteration_scannable(model, loop_state):
984985
if self.config.set_remat_policy_on_pipeline_iterations:
985986
run_iteration_scannable = nn.remat(
986987
run_iteration_scannable,
987-
prevent_cse=not self.config.scan_pipeline_iterations,
988+
prevent_cse=True, # not self.config.scan_pipeline_iterations,
988989
policy=self.get_pipeline_remat_policy(),
989990
)
990991

@@ -1017,7 +1018,7 @@ def run_one_repeat_scannable(model, loop_state):
10171018

10181019
run_one_repeat_scannable = nn.remat(
10191020
run_one_repeat_scannable,
1020-
prevent_cse=not self.config.scan_pipeline_iterations,
1021+
prevent_cse=True,
10211022
policy=self.get_pipeline_remat_policy(),
10221023
)
10231024

0 commit comments

Comments
 (0)