Skip to content

Commit 286e066

Browse files
Commit message: add another layer of custom vjp
1 parent: 6d94f5a · commit: 286e066

2 files changed

Lines changed: 211 additions & 188 deletions

File tree

src/maxtext/layers/pipeline.py

Lines changed: 97 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,77 @@ def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iterat
582582
bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec)
583583
return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
584584

585+
def _run_initialization(
586+
self,
587+
example_inputs,
588+
example_segmentation,
589+
example_position,
590+
segment_idx,
591+
position_idx,
592+
deterministic,
593+
model_mode,
594+
):
595+
"""Runs the initialization sequence mapping layers appropriately based on pipeline settings."""
596+
vmap_func = self.get_vmap_func_for_init()
597+
598+
if self.config.num_pipeline_repeats > 1:
599+
# To shard the weights on initialization for the circular pipeline we create weights of
600+
# shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
601+
# We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
602+
vmap_func = nn.vmap(
603+
vmap_func,
604+
in_axes=(0, segment_idx, position_idx, None, None),
605+
variable_axes={
606+
"params": 0,
607+
"_overwrite_with_gradient": 0,
608+
"non_trainable": 0,
609+
"hyper_params": 0,
610+
},
611+
split_rngs={"params": True, "dropout": self.config.enable_dropout},
612+
metadata_params={
613+
nn.PARTITION_NAME: "circular_repeats",
614+
"sub_weight_split_dims_mapping": (None,),
615+
"is_initializing": True,
616+
"x_times": self.config.num_pipeline_repeats,
617+
"optimizer_dims_mapping": None,
618+
},
619+
)
620+
621+
example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
622+
example_segmentation = (
623+
jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
624+
if example_segmentation is not None
625+
else None
626+
)
627+
example_position = (
628+
jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
629+
if example_position is not None
630+
else None
631+
)
632+
633+
# We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
634+
# the full total_iterations.
635+
example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
636+
stage_outputs = vmap_func(
637+
self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
638+
)
639+
if self.config.scan_layers:
640+
stage_outputs = stage_outputs[0]
641+
642+
# We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
643+
# which has shape [pipeline_microbatch_size, sequence, embed]
644+
if self.config.num_pipeline_repeats > 1:
645+
stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap
646+
broadcasted_stage_outpus = jax.lax.broadcast(
647+
stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
648+
)
649+
650+
return jnp.reshape(
651+
broadcasted_stage_outpus,
652+
[self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
653+
out_sharding=self.output_sharding,
654+
)
655+
585656
def get_vmap_func_for_init(self):
586657
"""This vmap func is used to initialize the weights only on init."""
587658

@@ -814,63 +885,8 @@ def __call__(
814885
bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
815886

816887
if self.is_initializing():
817-
vmap_func = self.get_vmap_func_for_init()
818-
819-
if self.config.num_pipeline_repeats > 1:
820-
# To shard the weights on initialization for the circular pipeline we create weights of
821-
# shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
822-
# We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
823-
vmap_func = nn.vmap(
824-
vmap_func,
825-
in_axes=(0, segment_idx, position_idx, None, None),
826-
variable_axes={
827-
"params": 0,
828-
"_overwrite_with_gradient": 0,
829-
"non_trainable": 0,
830-
"hyper_params": 0,
831-
},
832-
split_rngs={"params": True, "dropout": self.config.enable_dropout},
833-
metadata_params={
834-
nn.PARTITION_NAME: "circular_repeats",
835-
"sub_weight_split_dims_mapping": (None,),
836-
"is_initializing": True,
837-
"x_times": self.config.num_pipeline_repeats,
838-
"optimizer_dims_mapping": None,
839-
},
840-
)
841-
842-
example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
843-
example_segmentation = (
844-
jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
845-
if example_segmentation is not None
846-
else None
847-
)
848-
example_position = (
849-
jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
850-
if example_position is not None
851-
else None
852-
)
853-
# We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
854-
# the full total_iterations.
855-
example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
856-
stage_outputs = vmap_func(
857-
self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
858-
)
859-
if self.config.scan_layers:
860-
stage_outputs = stage_outputs[0]
861-
862-
# We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
863-
# which has shape [pipeline_microbatch_size, sequence, embed]
864-
if self.config.num_pipeline_repeats > 1:
865-
stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap
866-
broadcasted_stage_outpus = jax.lax.broadcast(
867-
stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
868-
)
869-
870-
return jnp.reshape(
871-
broadcasted_stage_outpus,
872-
[self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
873-
out_sharding=self.output_sharding,
888+
return self._run_initialization(
889+
example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
874890
)
875891

876892
logical_partition_spec = pipeline_utils.get_logical_spec_repeats_removed(logical_partition_spec)
@@ -897,95 +913,39 @@ def run_iteration_scannable(model, loop_state):
897913
policy=self.get_pipeline_remat_policy(),
898914
)
899915

900-
def run_one_repeat_scannable(model, loop_state):
901-
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
902-
loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
903-
)
904-
905-
if model.config.scan_pipeline_iterations:
906-
run_one_repeat_scanned_custom = pipeline_utils.create_scanned_function(
907-
model=model,
908-
run_iteration_scannable=run_iteration_scannable,
909-
length=model.config.num_pipeline_microbatches,
910-
variable_axes={
911-
"summaries": 0,
912-
"aux_loss": 0,
913-
"intermediates": 0,
914-
"hyper_params": 0,
915-
},
916-
split_rngs={"random": True},
917-
deterministic=deterministic,
918-
model_mode=model_mode,
919-
logical_partition_spec=logical_partition_spec,
920-
)
921-
loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids)
922-
else:
923-
for _ in range(model.config.num_pipeline_microbatches):
924-
loop_state, _ = run_iteration_scannable(model, loop_state)
925-
return loop_state, None
926-
927-
run_one_repeat_scannable = nn.remat(
928-
run_one_repeat_scannable,
929-
prevent_cse=not self.config.scan_pipeline_iterations,
930-
policy=self.get_pipeline_remat_policy(),
916+
run_one_repeat_scannable = pipeline_utils.create_run_scannable(
917+
model=self,
918+
run_iteration_scannable=run_iteration_scannable,
919+
length=self.config.num_pipeline_microbatches,
920+
deterministic=deterministic,
921+
model_mode=model_mode,
922+
logical_partition_spec=logical_partition_spec,
923+
physical_partition_spec=physical_partition_spec,
924+
positions=positions,
925+
segment_ids=segment_ids,
931926
)
932927

933-
def run_bubbles_scannable(model, loop_state):
934-
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
935-
loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
936-
)
937-
938-
if model.config.scan_pipeline_iterations:
939-
run_bubbles_scanned_custom = pipeline_utils.create_scanned_function(
940-
model=model,
941-
run_iteration_scannable=run_iteration_scannable,
942-
length=bubble_iterations,
943-
variable_axes={
944-
"summaries": 0,
945-
"aux_loss": 0,
946-
"intermediates": 0,
947-
"hyper_params": 0,
948-
},
949-
split_rngs={"random": True},
950-
deterministic=deterministic,
951-
model_mode=model_mode,
952-
logical_partition_spec=logical_partition_spec,
953-
)
954-
loop_state = run_bubbles_scanned_custom(loop_state, positions, segment_ids)
955-
else:
956-
for _ in range(model.config.num_pipeline_microbatches):
957-
loop_state, _ = run_iteration_scannable(model, loop_state)
958-
return loop_state, None
959-
960-
run_bubbles_scannable = nn.remat(
961-
run_bubbles_scannable,
962-
prevent_cse=not self.config.scan_pipeline_iterations,
963-
policy=self.get_pipeline_remat_policy(),
928+
run_bubbles_scannable = pipeline_utils.create_run_scannable(
929+
model=self,
930+
run_iteration_scannable=run_iteration_scannable,
931+
length=bubble_iterations,
932+
deterministic=deterministic,
933+
model_mode=model_mode,
934+
logical_partition_spec=logical_partition_spec,
935+
physical_partition_spec=physical_partition_spec,
936+
positions=positions,
937+
segment_ids=segment_ids,
964938
)
965939

966940
def run_all_iterations(model, loop_state):
967941
if self.config.scan_pipeline_repeats:
968-
run_repeats_scanned = nn.scan(
969-
run_one_repeat_scannable,
970-
variable_axes={
971-
"summaries": 0,
972-
"aux_loss": 0,
973-
"intermediates": 0,
974-
"hyper_params": 0,
975-
},
976-
split_rngs={"random": True},
942+
run_repeats_scanned = pipeline_utils.create_run_repeats_scanned(
943+
run_scannable=run_one_repeat_scannable,
977944
length=model.config.num_pipeline_repeats,
978945
)
979946

980-
run_bubbles_scanned = nn.scan(
981-
run_bubbles_scannable,
982-
variable_axes={
983-
"summaries": 0,
984-
"aux_loss": 0,
985-
"intermediates": 0,
986-
"hyper_params": 0,
987-
},
988-
split_rngs={"random": True},
947+
run_bubbles_scanned = pipeline_utils.create_run_repeats_scanned(
948+
run_scannable=run_bubbles_scannable,
989949
length=1,
990950
)
991951
loop_state, _ = run_repeats_scanned(model, loop_state)

0 commit comments

Comments (0)