Skip to content

Commit 6c22238

Browse files
committed
add custom vjp
1 parent 23c2849 commit 6c22238

1 file changed

Lines changed: 124 additions & 70 deletions

File tree

src/MaxText/layers/pipeline.py

Lines changed: 124 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Pipeline layer wrapping a decoder layer(s). Supports circular pipelining """
15+
"""Pipeline layer wrapping a decoder layer(s). Supports circular pipelining"""
1616

17-
import functools
17+
# import functools
1818
from typing import Any
19+
import functools
1920

2021
import numpy as np
2122

@@ -225,6 +226,7 @@ def _init_bsw_from_weights(variables):
225226
"loop_iteration": 0,
226227
"prev_outputs": prev_outputs,
227228
"bsw": bsw,
229+
"weights": self.layers.variables,
228230
}
229231
return init_loop_state
230232

@@ -455,6 +457,7 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx):
455457
"loop_iteration": loop_iteration + 1,
456458
"prev_outputs": new_prev_outputs,
457459
"bsw": loop_state["bsw"], # bsw is updated outside of this inner loop, only once per outer loop iteration
460+
"weights": loop_state["weights"], # Pass weights through
458461
}
459462
return new_loop_state
460463

@@ -469,7 +472,9 @@ def permute_output_micro_per_stage_dim(self, output):
469472
output = output[:, permutation]
470473
return output
471474

472-
def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physical_partition_spec=None):
475+
def get_current_stage_weights(
476+
self, pipeline_weights, bsw, loop_iteration, physical_partition_spec=None, is_initializing=None
477+
):
473478
"""
474479
Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g.
475480
{'mlp': {'wo': [stages, mlp, embed]}}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc.
@@ -479,11 +484,11 @@ def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physi
479484
pipeline_weights = self._remove_logically_partition(pipeline_weights)
480485
if self.config.num_pipeline_repeats > 1:
481486
pipeline_weights = self.get_current_weights_from_bsw(
482-
bsw, loop_iteration, physical_partition_spec=physical_partition_spec
487+
bsw, loop_iteration, physical_partition_spec=physical_partition_spec, is_initializing=is_initializing
483488
)
484489
return pipeline_weights
485490

486-
def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec):
491+
def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec, is_initializing=None):
487492
"""Collect and gather weights from given bsw (buffer sliding window)"""
488493
bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
489494
_, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
@@ -506,10 +511,13 @@ def select_weights_from_bsw(bsw, repeat_id):
506511

507512
weights = select_weights_from_bsw(bsw, repeat_ids)
508513

514+
if is_initializing is None:
515+
is_initializing = self.is_initializing()
516+
509517
circular_metadata_params = {
510518
nn.PARTITION_NAME: "circular_repeats",
511519
"sub_weight_split_dims_mapping": (None,),
512-
"is_initializing": self.is_initializing(),
520+
"is_initializing": is_initializing,
513521
"x_times": self.config.num_pipeline_repeats,
514522
"optimizer_dims_mapping": None,
515523
}
@@ -550,7 +558,7 @@ def find_fsdp(pspec):
550558

551559
return jax.tree.map(find_fsdp, physical_partition_spec)
552560

553-
def from_all_variables_to_repeat_weights(self, loop_iteration, physical_partition_spec):
561+
def from_all_variables_to_repeat_weights(self, weights, loop_iteration, physical_partition_spec):
554562
"""Generate one single repeat weight from all variables."""
555563
_, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
556564

@@ -559,24 +567,24 @@ def gather_weights_for_stages_in(w, spec):
559567
w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
560568
)
561569

562-
weights = self._remove_logically_partition(self.layers.variables)
570+
weights = self._remove_logically_partition(weights)
563571
if physical_partition_spec is None:
564-
repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights)
572+
weights = jax.tree.map(gather_weights_for_stages_in, weights)
565573
else:
566-
repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
574+
weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
567575
circular_metadata_params = {
568576
nn.PARTITION_NAME: "circular_repeats",
569577
"sub_weight_split_dims_mapping": (None,),
570578
"is_initializing": self.is_initializing(),
571579
"x_times": self.config.num_pipeline_repeats,
572580
"optimizer_dims_mapping": None,
573581
}
574-
repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params)
582+
repeat_weights = meta.remove_axis(weights, 0, circular_metadata_params)
575583
return repeat_weights
576584

577-
def from_all_variables_to_bsw(self, loop_iteration, physical_partition_spec):
585+
def from_all_variables_to_bsw(self, weights, loop_iteration, physical_partition_spec):
578586
"""All gather one branch of bsw using shardmap."""
579-
repeat_weights = self.from_all_variables_to_repeat_weights(loop_iteration, physical_partition_spec)
587+
repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec)
580588
bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
581589
repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
582590
fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec)
@@ -597,10 +605,10 @@ def _all_gather_invariant(x, i):
597605

598606
return _all_gather_inner(repeat_weights, fsdp_idx)
599607

600-
def bsw_all_gather_over_fsdp(self, physical_partition_spec, loop_iteration):
608+
def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iteration):
601609
"""All gather all bsw over fsdp mesh axis using shardmap."""
602-
bsw_0 = self.from_all_variables_to_bsw(loop_iteration, physical_partition_spec)
603-
bsw_1 = self.from_all_variables_to_bsw(loop_iteration + 1, physical_partition_spec)
610+
bsw_0 = self.from_all_variables_to_bsw(weights, loop_iteration, physical_partition_spec)
611+
bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec)
604612
return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
605613

606614
def get_vmap_func_for_init(self):
@@ -666,20 +674,22 @@ def func_to_vmap(
666674
def run_one_iteration(
667675
self,
668676
loop_state,
669-
pipeline_weights,
670677
positions,
671678
segment_ids,
672679
deterministic,
673680
model_mode,
674681
decoder_layer_instance,
675682
logical_partition_spec,
683+
vmap_func=None,
684+
is_initializing=None,
676685
):
677686
"""Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel,
678687
and update the loop state."""
679688
state_io = loop_state["state_io"]
680689
shift = loop_state["shift"]
681690
circ_storage = loop_state["circ_storage"]
682691
loop_iteration = loop_state["loop_iteration"]
692+
pipeline_weights = loop_state["weights"]
683693

684694
microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration)
685695

@@ -693,49 +703,15 @@ def run_one_iteration(
693703
stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None
694704
stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None
695705

696-
vmap_func = self.get_main_vmap_func_for_iterations()
697-
698-
if self.config.num_pipeline_repeats > 1:
699-
_, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
700-
701-
def prepare_vars_for_main_vmap(weights, physical_partition_spec=None):
702-
703-
circular_metadata_params = {
704-
nn.PARTITION_NAME: "circular_repeats",
705-
"sub_weight_split_dims_mapping": (None,),
706-
"is_initializing": self.is_initializing(),
707-
"x_times": self.config.num_pipeline_repeats,
708-
"optimizer_dims_mapping": None,
709-
}
710-
weights = meta.remove_axis(
711-
weights, 0, circular_metadata_params
712-
) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one
713-
# circular entry per stage.
714-
weights = self._remove_logically_partition(weights)
715-
716-
def gather_weights_for_stages_in(w, spec=None):
717-
return self.vmap_parallel_gather(
718-
w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
719-
)
720-
721-
if physical_partition_spec is None:
722-
weights = jax.tree.map(gather_weights_for_stages_in, weights)
723-
else:
724-
weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
725-
return weights
726-
727-
prepare_vars_for_main_vmap_partial = functools.partial(
728-
prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec
729-
)
730-
vmap_func = nn.map_variables(
731-
vmap_func,
732-
mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"],
733-
mutable=True,
734-
trans_in_fn=prepare_vars_for_main_vmap_partial,
735-
)
706+
if vmap_func is None:
707+
vmap_func = self.get_main_vmap_func_for_iterations()
736708

737709
stage_weights = self.get_current_stage_weights(
738-
pipeline_weights, loop_state["bsw"], loop_iteration, physical_partition_spec=physical_partition_spec
710+
pipeline_weights,
711+
loop_state["bsw"],
712+
loop_iteration,
713+
physical_partition_spec=physical_partition_spec,
714+
is_initializing=is_initializing,
739715
)
740716

741717
stages_output = vmap_func(
@@ -978,7 +954,6 @@ def run_iteration_scannable(model, loop_state):
978954
return (
979955
model.run_one_iteration(
980956
loop_state,
981-
model.layers.variables,
982957
positions,
983958
segment_ids,
984959
deterministic,
@@ -997,7 +972,9 @@ def run_iteration_scannable(model, loop_state):
997972
)
998973

999974
def run_one_repeat_scannable(model, loop_state):
1000-
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(physical_partition_spec, loop_state["loop_iteration"])
975+
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
976+
loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
977+
)
1001978

1002979
if model.config.scan_pipeline_iterations:
1003980
run_one_repeat_scanned = nn.scan(
@@ -1014,7 +991,85 @@ def run_one_repeat_scannable(model, loop_state):
1014991
split_rngs={"random": True},
1015992
length=model.config.num_pipeline_microbatches,
1016993
)
1017-
loop_state, _ = run_one_repeat_scanned(model, loop_state)
994+
995+
@jax.custom_vjp
def run_one_repeat_scanned_custom(loop_state, positions, segment_ids):
  """Run one scanned repeat and return only the resulting loop state.

  This is the primal function of a `jax.custom_vjp`; its forward and backward
  rules are attached afterwards with `defvjp`.

  Args:
    loop_state: carry handed to `run_one_repeat_scanned`.
    positions: not read in this function; it is part of the signature shared
      with the forward rule, which saves it as a residual.
    segment_ids: not read in this function; same role as `positions`.

  Returns:
    The final loop state. The second output of `run_one_repeat_scanned` is dropped.
  """
  # `model` and `run_one_repeat_scanned` are taken from the enclosing scope.
  final_state, _ = run_one_repeat_scanned(model, loop_state)
  return final_state
999+
1000+
def run_one_repeat_scanned_custom_fwd(loop_state, positions, segment_ids):
  """Forward rule for `run_one_repeat_scanned_custom`.

  Returns:
    A pair `(final_state, residuals)`. `residuals` is the unmodified inputs
    `(loop_state, positions, segment_ids)`, which the backward rule unpacks.
  """
  # Only the inputs are kept as residuals. `model` is not among them: both this
  # rule and the backward rule read it from the enclosing scope.
  residuals = (loop_state, positions, segment_ids)
  end_state, _ = run_one_repeat_scanned(model, loop_state)
  return end_state, residuals
1008+
1009+
def run_one_repeat_scanned_custom_bwd(residuals, g_final_state):
  """Backward rule for `run_one_repeat_scanned_custom`.

  The forward rule keeps only the initial loop state. This rule therefore first
  replays every iteration to collect the state each one started from, then walks
  those states in reverse, pulling the cotangent back through one iteration at a
  time with `jax.vjp`.

  Args:
    residuals: `(init_loop_state, positions, segment_ids)` saved by the forward rule.
    g_final_state: cotangent of the final loop state.

  Returns:
    `(d_init_state, None, None)`: the cotangent with respect to the initial loop
    state, and `None` for `positions` and `segment_ids`.
  """
  init_loop_state, positions, segment_ids = residuals

  def one_iteration(state):
    # `model`, `deterministic`, `model_mode` and `logical_partition_spec` are
    # taken from the enclosing scope.
    return model.run_one_iteration(
        state,
        positions,
        segment_ids,
        deterministic,
        model_mode,
        model.layers,
        logical_partition_spec=logical_partition_spec,
    )

  def replay_step(state, _):
    next_state = one_iteration(state)
    # Record the state this iteration started from (not the one it produced).
    # "bsw" and "weights" are left out; they are re-attached from
    # `init_loop_state` when the iteration is rebuilt below.
    snapshot = {key: value for key, value in state.items() if key not in ("bsw", "weights")}
    return next_state, snapshot

  # NOTE(review): the replay uses a plain `jax.lax.scan`, while the forward pass
  # goes through `run_one_repeat_scanned`. Whether the two produce identical
  # per-iteration states (e.g. with respect to RNG handling) cannot be confirmed
  # from this function alone.
  _, saved_states = jax.lax.scan(
      replay_step,
      init_loop_state,
      None,
      length=model.config.num_pipeline_microbatches,
  )

  def pullback_step(d_next_state, saved_slice):
    # Rebuild the full input state of this iteration.
    curr_loop_state = {
        **saved_slice,
        "bsw": init_loop_state["bsw"],
        "weights": init_loop_state["weights"],
    }
    _, vjp_fun = jax.vjp(one_iteration, curr_loop_state)
    (d_curr_state,) = vjp_fun(d_next_state)
    return d_curr_state, None

  # Visit the saved states from last to first, carrying the cotangent.
  d_init_state, _ = jax.lax.scan(pullback_step, g_final_state, saved_states, reverse=True)
  return (d_init_state, None, None)
1069+
1070+
run_one_repeat_scanned_custom.defvjp(run_one_repeat_scanned_custom_fwd, run_one_repeat_scanned_custom_bwd)
1071+
1072+
loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids)
10181073
else:
10191074
for _ in range(model.config.num_pipeline_microbatches):
10201075
loop_state, _ = run_iteration_scannable(model, loop_state)
@@ -1056,7 +1111,9 @@ def run_all_iterations(model, loop_state):
10561111
length=bubble_iterations,
10571112
)
10581113
loop_state, _ = run_repeats_scanned(model, loop_state)
1059-
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(physical_partition_spec, loop_state["loop_iteration"])
1114+
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
1115+
loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
1116+
)
10601117
loop_state, _ = run_bubbles_scanned(model, loop_state)
10611118
else:
10621119
for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop
@@ -1068,14 +1125,11 @@ def run_all_iterations(model, loop_state):
10681125
# The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized.
10691126
# if self.config.scan_pipeline_iterations:
10701127
variable_carry = []
1071-
variable_broadcast = [
1072-
"params",
1073-
"_overwrite_with_gradient",
1074-
] # All loop iterations need the weights for the full pipeline.
1075-
if self.is_mutable_collection("non_trainable"):
1076-
variable_carry.append("non_trainable")
1077-
else:
1078-
variable_broadcast.append("non_trainable")
1128+
variable_broadcast = [] # Nothing is broadcast here; the pipeline weights are carried in loop_state["weights"] instead.
1129+
# if self.is_mutable_collection("non_trainable"):
1130+
# variable_carry.append("non_trainable")
1131+
# else:
1132+
# variable_broadcast.append("non_trainable")
10791133

10801134
loop_state = run_all_iterations(self, loop_state)
10811135

0 commit comments

Comments
 (0)