1515"""Pipeline layer wrapping a decoder layer(s). Supports circular pipelining"""
1616
1717from typing import Any
18+ import functools
1819
19- import numpy as np
2020from maxtext .utils import pipeline_utils
2121
2222from jax import numpy as jnp
@@ -469,11 +469,8 @@ def permute_output_micro_per_stage_dim(self, output):
469469 # The first real output (microbatch 0) takes a certain amount of loop iterations to finish and be pushed to
470470 # state_io - it will land on a different index of state_io depending on the number of iterations.
471471 microbatch_0_idx = self .iterations_to_complete_first_microbatch () % self .microbatches_per_stage
472- permutation = (
473- np .arange (self .microbatches_per_stage ) + microbatch_0_idx
474- ) % self .microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear
475- # in idx 1, etc
476- output = output [:, permutation ]
472+ output = jnp .roll (output , shift = - microbatch_0_idx , axis = 1 )
473+ output = self ._maybe_shard_with_logical (output , self .state_io_logical )
477474 return output
478475
479476 def get_current_stage_weights (
@@ -583,6 +580,77 @@ def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iterat
583580 bsw_1 = self .from_all_variables_to_bsw (weights , loop_iteration + 1 , physical_partition_spec )
584581 return jax .ad_checkpoint .checkpoint_name ((bsw_0 , bsw_1 ), "bsw" )
585582
583+ def _run_initialization (
584+ self ,
585+ example_inputs ,
586+ example_segmentation ,
587+ example_position ,
588+ segment_idx ,
589+ position_idx ,
590+ deterministic ,
591+ model_mode ,
592+ ):
593+ """Runs the initialization sequence mapping layers appropriately based on pipeline settings."""
594+ vmap_func = self .get_vmap_func_for_init ()
595+
596+ if self .config .num_pipeline_repeats > 1 :
597+ # To shard the weights on initialization for the circular pipeline we create weights of
598+ # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
599+ # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
600+ vmap_func = nn .vmap (
601+ vmap_func ,
602+ in_axes = (0 , segment_idx , position_idx , None , None ),
603+ variable_axes = {
604+ "params" : 0 ,
605+ "_overwrite_with_gradient" : 0 ,
606+ "non_trainable" : 0 ,
607+ "hyper_params" : 0 ,
608+ },
609+ split_rngs = {"params" : True , "dropout" : self .config .enable_dropout },
610+ metadata_params = {
611+ nn .PARTITION_NAME : "circular_repeats" ,
612+ "sub_weight_split_dims_mapping" : (None ,),
613+ "is_initializing" : True ,
614+ "x_times" : self .config .num_pipeline_repeats ,
615+ "optimizer_dims_mapping" : None ,
616+ },
617+ )
618+
619+ example_inputs = jax .lax .broadcast (example_inputs , [self .config .num_pipeline_repeats ])
620+ example_segmentation = (
621+ jax .lax .broadcast (example_segmentation , [self .config .num_pipeline_repeats ])
622+ if example_segmentation is not None
623+ else None
624+ )
625+ example_position = (
626+ jax .lax .broadcast (example_position , [self .config .num_pipeline_repeats ])
627+ if example_position is not None
628+ else None
629+ )
630+
631+ # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
632+ # the full total_iterations.
633+ example_inputs = self ._maybe_shard_with_logical (example_inputs , (None , None , None , None ))
634+ stage_outputs = vmap_func (
635+ self .layers , example_inputs , example_segmentation , example_position , deterministic , model_mode
636+ )
637+ if self .config .scan_layers :
638+ stage_outputs = stage_outputs [0 ]
639+
640+ # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
641+ # which has shape [pipeline_microbatch_size, sequence, embed]
642+ if self .config .num_pipeline_repeats > 1 :
643+ stage_outputs = stage_outputs [0 ] # Remove extra dimension created for the circular vmap
644+ broadcasted_stage_outpus = jax .lax .broadcast (
645+ stage_outputs [0 ], [self .config .micro_batch_size_to_train_on // self .pipeline_microbatch_size ]
646+ )
647+
648+ return jnp .reshape (
649+ broadcasted_stage_outpus ,
650+ [self .config .micro_batch_size_to_train_on , self .config .max_target_length , self .config .emb_dim ],
651+ out_sharding = self .output_sharding ,
652+ )
653+
586654 def get_vmap_func_for_init (self ):
587655 """This vmap func is used to initialize the weights only on init."""
588656
@@ -815,63 +883,8 @@ def __call__(
815883 bubble_iterations = self .forwarding_delay * (self .num_stages - 1 )
816884
817885 if self .is_initializing ():
818- vmap_func = self .get_vmap_func_for_init ()
819-
820- if self .config .num_pipeline_repeats > 1 :
821- # To shard the weights on initialization for the circular pipeline we create weights of
822- # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
823- # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
824- vmap_func = nn .vmap (
825- vmap_func ,
826- in_axes = (0 , segment_idx , position_idx , None , None ),
827- variable_axes = {
828- "params" : 0 ,
829- "_overwrite_with_gradient" : 0 ,
830- "non_trainable" : 0 ,
831- "hyper_params" : 0 ,
832- },
833- split_rngs = {"params" : True , "dropout" : self .config .enable_dropout },
834- metadata_params = {
835- nn .PARTITION_NAME : "circular_repeats" ,
836- "sub_weight_split_dims_mapping" : (None ,),
837- "is_initializing" : True ,
838- "x_times" : self .config .num_pipeline_repeats ,
839- "optimizer_dims_mapping" : None ,
840- },
841- )
842-
843- example_inputs = jax .lax .broadcast (example_inputs , [self .config .num_pipeline_repeats ])
844- example_segmentation = (
845- jax .lax .broadcast (example_segmentation , [self .config .num_pipeline_repeats ])
846- if example_segmentation is not None
847- else None
848- )
849- example_position = (
850- jax .lax .broadcast (example_position , [self .config .num_pipeline_repeats ])
851- if example_position is not None
852- else None
853- )
854- # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
855- # the full total_iterations.
856- example_inputs = self ._maybe_shard_with_logical (example_inputs , (None , None , None , None ))
857- stage_outputs = vmap_func (
858- self .layers , example_inputs , example_segmentation , example_position , deterministic , model_mode
859- )
860- if self .config .scan_layers :
861- stage_outputs = stage_outputs [0 ]
862-
863- # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
864- # which has shape [pipeline_microbatch_size, sequence, embed]
865- if self .config .num_pipeline_repeats > 1 :
866- stage_outputs = stage_outputs [0 ] # Remove extra dimension created for the circular vmap
867- broadcasted_stage_outpus = jax .lax .broadcast (
868- stage_outputs [0 ], [self .config .micro_batch_size_to_train_on // self .pipeline_microbatch_size ]
869- )
870-
871- return jnp .reshape (
872- broadcasted_stage_outpus ,
873- [self .config .micro_batch_size_to_train_on , self .config .max_target_length , self .config .emb_dim ],
874- out_sharding = self .output_sharding ,
886+ return self ._run_initialization (
887+ example_inputs , example_segmentation , example_position , segment_idx , position_idx , deterministic , model_mode
875888 )
876889
877890 logical_partition_spec = pipeline_utils .get_logical_spec_repeats_removed (logical_partition_spec )
@@ -898,95 +911,37 @@ def run_iteration_scannable(model, loop_state):
898911 policy = self .get_pipeline_remat_policy (),
899912 )
900913
901- def run_one_repeat_scannable (model , loop_state ):
902- loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (
903- loop_state ["weights" ], physical_partition_spec , loop_state ["loop_iteration" ]
904- )
905-
906- if model .config .scan_pipeline_iterations :
907- run_one_repeat_scanned_custom = pipeline_utils .create_scanned_function (
908- model = model ,
909- run_iteration_scannable = run_iteration_scannable ,
910- length = model .config .num_pipeline_microbatches ,
911- variable_axes = {
912- "summaries" : 0 ,
913- "aux_loss" : 0 ,
914- "intermediates" : 0 ,
915- "hyper_params" : 0 ,
916- },
917- split_rngs = {"random" : True },
918- deterministic = deterministic ,
919- model_mode = model_mode ,
920- logical_partition_spec = logical_partition_spec ,
921- )
922- loop_state = run_one_repeat_scanned_custom (loop_state , positions , segment_ids )
923- else :
924- for _ in range (model .config .num_pipeline_microbatches ):
925- loop_state , _ = run_iteration_scannable (model , loop_state )
926- return loop_state , None
927-
928- run_one_repeat_scannable = nn .remat (
929- run_one_repeat_scannable ,
930- prevent_cse = not self .config .scan_pipeline_iterations ,
931- policy = self .get_pipeline_remat_policy (),
914+ base_scannable = functools .partial (
915+ pipeline_utils .create_run_scannable ,
916+ model = self ,
917+ run_iteration_scannable = run_iteration_scannable ,
918+ deterministic = deterministic ,
919+ model_mode = model_mode ,
920+ logical_partition_spec = logical_partition_spec ,
921+ physical_partition_spec = physical_partition_spec ,
922+ positions = positions ,
923+ segment_ids = segment_ids ,
932924 )
933925
934- def run_bubbles_scannable (model , loop_state ):
935- loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (
936- loop_state ["weights" ], physical_partition_spec , loop_state ["loop_iteration" ]
937- )
938-
939- if model .config .scan_pipeline_iterations :
940- run_bubbles_scanned_custom = pipeline_utils .create_scanned_function (
941- model = model ,
942- run_iteration_scannable = run_iteration_scannable ,
943- length = bubble_iterations ,
944- variable_axes = {
945- "summaries" : 0 ,
946- "aux_loss" : 0 ,
947- "intermediates" : 0 ,
948- "hyper_params" : 0 ,
949- },
950- split_rngs = {"random" : True },
951- deterministic = deterministic ,
952- model_mode = model_mode ,
953- logical_partition_spec = logical_partition_spec ,
954- )
955- loop_state = run_bubbles_scanned_custom (loop_state , positions , segment_ids )
956- else :
957- for _ in range (model .config .num_pipeline_microbatches ):
958- loop_state , _ = run_iteration_scannable (model , loop_state )
959- return loop_state , None
926+ run_one_repeat_scannable = base_scannable (
927+ length = self .config .num_pipeline_microbatches ,
928+ )
960929
961- run_bubbles_scannable = nn .remat (
962- run_bubbles_scannable ,
963- prevent_cse = not self .config .scan_pipeline_iterations ,
964- policy = self .get_pipeline_remat_policy (),
930+ run_bubbles_scannable = base_scannable (
931+ length = bubble_iterations ,
965932 )
966933
967934 def run_all_iterations (model , loop_state ):
968935 if self .config .scan_pipeline_repeats :
969- run_repeats_scanned = nn .scan (
970- run_one_repeat_scannable ,
971- variable_axes = {
972- "summaries" : 0 ,
973- "aux_loss" : 0 ,
974- "intermediates" : 0 ,
975- "hyper_params" : 0 ,
976- },
977- split_rngs = {"random" : True },
936+ run_repeats_scanned = pipeline_utils .create_run_repeats_scanned (
937+ run_scannable = run_one_repeat_scannable ,
938+ model = model ,
978939 length = model .config .num_pipeline_repeats ,
979940 )
980941
981- run_bubbles_scanned = nn .scan (
982- run_bubbles_scannable ,
983- variable_axes = {
984- "summaries" : 0 ,
985- "aux_loss" : 0 ,
986- "intermediates" : 0 ,
987- "hyper_params" : 0 ,
988- },
989- split_rngs = {"random" : True },
942+ run_bubbles_scanned = pipeline_utils .create_run_repeats_scanned (
943+ run_scannable = run_bubbles_scannable ,
944+ model = model ,
990945 length = 1 ,
991946 )
992947 loop_state , _ = run_repeats_scanned (model , loop_state )
0 commit comments