Skip to content

Commit 2286894

Browse files
committed
split bsw all gather into two
1 parent 215a736 commit 2286894

1 file changed

Lines changed: 26 additions & 22 deletions

File tree

src/MaxText/layers/pipeline.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -550,24 +550,20 @@ def find_fsdp(pspec):
550550

551551
return jax.tree.map(find_fsdp, physical_partition_spec)
552552

553-
def bsw_all_gather_over_fsdp(self, weights, bsw, physical_partition_spec, loop_iteration):
554-
"""All gather bsw over fsdp mesh axis using shardmap."""
555-
bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
556-
repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
557-
fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec)
558-
559-
_, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration + 1)
553+
def from_all_variables_to_repeat_weights(self, loop_iteration, physical_partition_spec):
554+
"""Generate one single repeat weight from all variables."""
555+
_, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
560556

561557
def gather_weights_for_stages_in(w, spec):
562558
return self.vmap_parallel_gather(
563559
w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
564560
)
565561

562+
weights = self._remove_logically_partition(self.layers.variables)
566563
if physical_partition_spec is None:
567564
repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights)
568565
else:
569566
repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
570-
571567
circular_metadata_params = {
572568
nn.PARTITION_NAME: "circular_repeats",
573569
"sub_weight_split_dims_mapping": (None,),
@@ -576,25 +572,36 @@ def gather_weights_for_stages_in(w, spec):
576572
"optimizer_dims_mapping": None,
577573
}
578574
repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params)
575+
return repeat_weights
576+
577+
def from_all_variables_to_bsw(self, loop_iteration, physical_partition_spec):
578+
"""All gather one branch of bsw using shardmap."""
579+
repeat_weights = self.from_all_variables_to_repeat_weights(loop_iteration, physical_partition_spec)
580+
bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
581+
repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
582+
fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec)
579583

580584
@jax.shard_map(
581585
mesh=self.mesh,
582-
in_specs=(repeat_weights_pps, (bsw_pps, bsw_pps), None),
583-
out_specs=(bsw_pps, bsw_pps),
586+
in_specs=(repeat_weights_pps, None),
587+
out_specs=bsw_pps,
584588
check_vma=True,
585589
)
586-
def _all_gather_inner(sharded_weights, cur_bsw, fsdp_idx):
590+
def _all_gather_inner(sharded_weights, fsdp_idx):
587591
def _all_gather_invariant(x, i):
588592
if i >= 0:
589593
return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
590594
return x
591595

592-
new_variables = jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
593-
new_variables = jax.ad_checkpoint.checkpoint_name(new_variables, "bsw_gathered_weights")
596+
return jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
594597

595-
return jax.ad_checkpoint.checkpoint_name((cur_bsw[1], new_variables), "bsw")
598+
return _all_gather_inner(repeat_weights, fsdp_idx)
596599

597-
return _all_gather_inner(repeat_weights, bsw, fsdp_idx)
600+
def bsw_all_gather_over_fsdp(self, physical_partition_spec, loop_iteration):
601+
"""All gather all bsw over fsdp mesh axis using shardmap."""
602+
bsw_0 = self.from_all_variables_to_bsw(loop_iteration, physical_partition_spec)
603+
bsw_1 = self.from_all_variables_to_bsw(loop_iteration + 1, physical_partition_spec)
604+
return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
598605

599606
def get_vmap_func_for_init(self):
600607
"""This vmap func is used to initialize the weights only on init."""
@@ -985,15 +992,12 @@ def run_iteration_scannable(model, loop_state):
985992
if self.config.set_remat_policy_on_pipeline_iterations:
986993
run_iteration_scannable = nn.remat(
987994
run_iteration_scannable,
988-
prevent_cse=True, # not self.config.scan_pipeline_iterations,
995+
prevent_cse=not self.config.scan_pipeline_iterations,
989996
policy=self.get_pipeline_remat_policy(),
990997
)
991998

992999
def run_one_repeat_scannable(model, loop_state):
993-
weights = model._remove_logically_partition(model.layers.variables) # pylint: disable=protected-access
994-
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
995-
weights, loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
996-
)
1000+
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(physical_partition_spec, loop_state["loop_iteration"])
9971001

9981002
if model.config.scan_pipeline_iterations:
9991003
run_one_repeat_scanned = nn.scan(
@@ -1018,7 +1022,7 @@ def run_one_repeat_scannable(model, loop_state):
10181022

10191023
run_one_repeat_scannable = nn.remat(
10201024
run_one_repeat_scannable,
1021-
prevent_cse=True,
1025+
prevent_cse=not self.config.scan_pipeline_iterations,
10221026
policy=self.get_pipeline_remat_policy(),
10231027
)
10241028

@@ -1052,7 +1056,7 @@ def run_all_iterations(model, loop_state):
10521056
length=bubble_iterations,
10531057
)
10541058
loop_state, _ = run_repeats_scanned(model, loop_state)
1055-
loop_state["bsw"] = (loop_state["bsw"][1], jax.tree.map(jnp.zeros_like, loop_state["bsw"][1]))
1059+
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(physical_partition_spec, loop_state["loop_iteration"])
10561060
loop_state, _ = run_bubbles_scanned(model, loop_state)
10571061
else:
10581062
for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop

0 commit comments

Comments (0)