Skip to content

Commit 0efb56e

Browse files
committed
add another layer of custom vjp
1 parent 5cf5c32 commit 0efb56e

3 files changed

Lines changed: 203 additions & 206 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11181118
pre_bias_logits,
11191119
self.config.use_custom_sort_vjp,
11201120
roll_to_expert_id=num_experts_per_shard * expert_shard_id,
1121+
rngs=rngs,
11211122
)
11221123

11231124
# Filter down to the group sizes that apply to only the experts in the

src/maxtext/layers/pipeline.py

Lines changed: 97 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
"""Pipeline layer wrapping a decoder layer(s). Supports circular pipelining"""
1616

1717
from typing import Any
18+
import functools
1819

19-
import numpy as np
2020
from maxtext.utils import pipeline_utils
2121

2222
from jax import numpy as jnp
@@ -469,11 +469,8 @@ def permute_output_micro_per_stage_dim(self, output):
469469
# The first real output (microbatch 0) takes a certain amount of loop iterations to finish and be pushed to
470470
# state_io - it will land on a different index of state_io depending on the number of iterations.
471471
microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage
472-
permutation = (
473-
np.arange(self.microbatches_per_stage) + microbatch_0_idx
474-
) % self.microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear
475-
# in idx 1, etc
476-
output = output[:, permutation]
472+
output = jnp.roll(output, shift=-microbatch_0_idx, axis=1)
473+
output = self._maybe_shard_with_logical(output, self.state_io_logical)
477474
return output
478475

479476
def get_current_stage_weights(
@@ -583,6 +580,77 @@ def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iterat
583580
bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec)
584581
return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
585582

583+
def _run_initialization(
584+
self,
585+
example_inputs,
586+
example_segmentation,
587+
example_position,
588+
segment_idx,
589+
position_idx,
590+
deterministic,
591+
model_mode,
592+
):
593+
"""Runs the initialization sequence mapping layers appropriately based on pipeline settings."""
594+
vmap_func = self.get_vmap_func_for_init()
595+
596+
if self.config.num_pipeline_repeats > 1:
597+
# To shard the weights on initialization for the circular pipeline we create weights of
598+
# shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
599+
# We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
600+
vmap_func = nn.vmap(
601+
vmap_func,
602+
in_axes=(0, segment_idx, position_idx, None, None),
603+
variable_axes={
604+
"params": 0,
605+
"_overwrite_with_gradient": 0,
606+
"non_trainable": 0,
607+
"hyper_params": 0,
608+
},
609+
split_rngs={"params": True, "dropout": self.config.enable_dropout},
610+
metadata_params={
611+
nn.PARTITION_NAME: "circular_repeats",
612+
"sub_weight_split_dims_mapping": (None,),
613+
"is_initializing": True,
614+
"x_times": self.config.num_pipeline_repeats,
615+
"optimizer_dims_mapping": None,
616+
},
617+
)
618+
619+
example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
620+
example_segmentation = (
621+
jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
622+
if example_segmentation is not None
623+
else None
624+
)
625+
example_position = (
626+
jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
627+
if example_position is not None
628+
else None
629+
)
630+
631+
# We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
632+
# the full total_iterations.
633+
example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
634+
stage_outputs = vmap_func(
635+
self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
636+
)
637+
if self.config.scan_layers:
638+
stage_outputs = stage_outputs[0]
639+
640+
# We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
641+
# which has shape [pipeline_microbatch_size, sequence, embed]
642+
if self.config.num_pipeline_repeats > 1:
643+
stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap
644+
broadcasted_stage_outpus = jax.lax.broadcast(
645+
stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
646+
)
647+
648+
return jnp.reshape(
649+
broadcasted_stage_outpus,
650+
[self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
651+
out_sharding=self.output_sharding,
652+
)
653+
586654
def get_vmap_func_for_init(self):
587655
"""This vmap func is used to initialize the weights only on init."""
588656

@@ -815,63 +883,8 @@ def __call__(
815883
bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
816884

817885
if self.is_initializing():
818-
vmap_func = self.get_vmap_func_for_init()
819-
820-
if self.config.num_pipeline_repeats > 1:
821-
# To shard the weights on initialization for the circular pipeline we create weights of
822-
# shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
823-
# We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
824-
vmap_func = nn.vmap(
825-
vmap_func,
826-
in_axes=(0, segment_idx, position_idx, None, None),
827-
variable_axes={
828-
"params": 0,
829-
"_overwrite_with_gradient": 0,
830-
"non_trainable": 0,
831-
"hyper_params": 0,
832-
},
833-
split_rngs={"params": True, "dropout": self.config.enable_dropout},
834-
metadata_params={
835-
nn.PARTITION_NAME: "circular_repeats",
836-
"sub_weight_split_dims_mapping": (None,),
837-
"is_initializing": True,
838-
"x_times": self.config.num_pipeline_repeats,
839-
"optimizer_dims_mapping": None,
840-
},
841-
)
842-
843-
example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
844-
example_segmentation = (
845-
jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
846-
if example_segmentation is not None
847-
else None
848-
)
849-
example_position = (
850-
jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
851-
if example_position is not None
852-
else None
853-
)
854-
# We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
855-
# the full total_iterations.
856-
example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
857-
stage_outputs = vmap_func(
858-
self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
859-
)
860-
if self.config.scan_layers:
861-
stage_outputs = stage_outputs[0]
862-
863-
# We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
864-
# which has shape [pipeline_microbatch_size, sequence, embed]
865-
if self.config.num_pipeline_repeats > 1:
866-
stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap
867-
broadcasted_stage_outpus = jax.lax.broadcast(
868-
stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
869-
)
870-
871-
return jnp.reshape(
872-
broadcasted_stage_outpus,
873-
[self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
874-
out_sharding=self.output_sharding,
886+
return self._run_initialization(
887+
example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
875888
)
876889

877890
logical_partition_spec = pipeline_utils.get_logical_spec_repeats_removed(logical_partition_spec)
@@ -898,95 +911,37 @@ def run_iteration_scannable(model, loop_state):
898911
policy=self.get_pipeline_remat_policy(),
899912
)
900913

901-
def run_one_repeat_scannable(model, loop_state):
902-
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
903-
loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
904-
)
905-
906-
if model.config.scan_pipeline_iterations:
907-
run_one_repeat_scanned_custom = pipeline_utils.create_scanned_function(
908-
model=model,
909-
run_iteration_scannable=run_iteration_scannable,
910-
length=model.config.num_pipeline_microbatches,
911-
variable_axes={
912-
"summaries": 0,
913-
"aux_loss": 0,
914-
"intermediates": 0,
915-
"hyper_params": 0,
916-
},
917-
split_rngs={"random": True},
918-
deterministic=deterministic,
919-
model_mode=model_mode,
920-
logical_partition_spec=logical_partition_spec,
921-
)
922-
loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids)
923-
else:
924-
for _ in range(model.config.num_pipeline_microbatches):
925-
loop_state, _ = run_iteration_scannable(model, loop_state)
926-
return loop_state, None
927-
928-
run_one_repeat_scannable = nn.remat(
929-
run_one_repeat_scannable,
930-
prevent_cse=not self.config.scan_pipeline_iterations,
931-
policy=self.get_pipeline_remat_policy(),
914+
base_scannable = functools.partial(
915+
pipeline_utils.create_run_scannable,
916+
model=self,
917+
run_iteration_scannable=run_iteration_scannable,
918+
deterministic=deterministic,
919+
model_mode=model_mode,
920+
logical_partition_spec=logical_partition_spec,
921+
physical_partition_spec=physical_partition_spec,
922+
positions=positions,
923+
segment_ids=segment_ids,
932924
)
933925

934-
def run_bubbles_scannable(model, loop_state):
935-
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
936-
loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
937-
)
938-
939-
if model.config.scan_pipeline_iterations:
940-
run_bubbles_scanned_custom = pipeline_utils.create_scanned_function(
941-
model=model,
942-
run_iteration_scannable=run_iteration_scannable,
943-
length=bubble_iterations,
944-
variable_axes={
945-
"summaries": 0,
946-
"aux_loss": 0,
947-
"intermediates": 0,
948-
"hyper_params": 0,
949-
},
950-
split_rngs={"random": True},
951-
deterministic=deterministic,
952-
model_mode=model_mode,
953-
logical_partition_spec=logical_partition_spec,
954-
)
955-
loop_state = run_bubbles_scanned_custom(loop_state, positions, segment_ids)
956-
else:
957-
for _ in range(model.config.num_pipeline_microbatches):
958-
loop_state, _ = run_iteration_scannable(model, loop_state)
959-
return loop_state, None
926+
run_one_repeat_scannable = base_scannable(
927+
length=self.config.num_pipeline_microbatches,
928+
)
960929

961-
run_bubbles_scannable = nn.remat(
962-
run_bubbles_scannable,
963-
prevent_cse=not self.config.scan_pipeline_iterations,
964-
policy=self.get_pipeline_remat_policy(),
930+
run_bubbles_scannable = base_scannable(
931+
length=bubble_iterations,
965932
)
966933

967934
def run_all_iterations(model, loop_state):
968935
if self.config.scan_pipeline_repeats:
969-
run_repeats_scanned = nn.scan(
970-
run_one_repeat_scannable,
971-
variable_axes={
972-
"summaries": 0,
973-
"aux_loss": 0,
974-
"intermediates": 0,
975-
"hyper_params": 0,
976-
},
977-
split_rngs={"random": True},
936+
run_repeats_scanned = pipeline_utils.create_run_repeats_scanned(
937+
run_scannable=run_one_repeat_scannable,
938+
model=model,
978939
length=model.config.num_pipeline_repeats,
979940
)
980941

981-
run_bubbles_scanned = nn.scan(
982-
run_bubbles_scannable,
983-
variable_axes={
984-
"summaries": 0,
985-
"aux_loss": 0,
986-
"intermediates": 0,
987-
"hyper_params": 0,
988-
},
989-
split_rngs={"random": True},
942+
run_bubbles_scanned = pipeline_utils.create_run_repeats_scanned(
943+
run_scannable=run_bubbles_scannable,
944+
model=model,
990945
length=1,
991946
)
992947
loop_state, _ = run_repeats_scanned(model, loop_state)

0 commit comments

Comments (0)