@@ -582,6 +582,77 @@ def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iteration):
     bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec)
     return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
 
+  def _run_initialization(
+      self,
+      example_inputs,
+      example_segmentation,
+      example_position,
+      segment_idx,
+      position_idx,
+      deterministic,
+      model_mode,
+  ):
+    """Runs the initialization sequence, mapping layers appropriately based on pipeline settings."""
+    vmap_func = self.get_vmap_func_for_init()
+
+    if self.config.num_pipeline_repeats > 1:
+      # To shard the weights on initialization for the circular pipeline we create weights of
+      # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
+      # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
+      vmap_func = nn.vmap(
+          vmap_func,
+          in_axes=(0, segment_idx, position_idx, None, None),
+          variable_axes={
+              "params": 0,
+              "_overwrite_with_gradient": 0,
+              "non_trainable": 0,
+              "hyper_params": 0,
+          },
+          split_rngs={"params": True, "dropout": self.config.enable_dropout},
+          metadata_params={
+              nn.PARTITION_NAME: "circular_repeats",
+              "sub_weight_split_dims_mapping": (None,),
+              "is_initializing": True,
+              "x_times": self.config.num_pipeline_repeats,
+              "optimizer_dims_mapping": None,
+          },
+      )
+
+      example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
+      example_segmentation = (
+          jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
+          if example_segmentation is not None
+          else None
+      )
+      example_position = (
+          jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
+          if example_position is not None
+          else None
+      )
+
+    # We only need to run one set of stages to initialize the variables, instead of looping over all
+    # microbatches for the full total_iterations.
+    example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
+    stage_outputs = vmap_func(
+        self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
+    )
+    if self.config.scan_layers:
+      stage_outputs = stage_outputs[0]
+
+    # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stage's
+    # output, which has shape [pipeline_microbatch_size, sequence, embed].
+    if self.config.num_pipeline_repeats > 1:
+      stage_outputs = stage_outputs[0]  # Remove the extra dimension created for the circular vmap.
+    broadcasted_stage_outpus = jax.lax.broadcast(
+        stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
+    )
+
+    return jnp.reshape(
+        broadcasted_stage_outpus,
+        [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
+        out_sharding=self.output_sharding,
+    )
+
   def get_vmap_func_for_init(self):
     """This vmap func is used to initialize the weights only on init."""
 
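For reference, a minimal sketch of the lifted-vmap trick the circular-repeat wrapper above relies on: wrapping a module in nn.vmap with variable_axes={"params": 0} stacks a new leading axis onto every parameter at init. TinyStage and the toy shapes here are hypothetical stand-ins, not part of this change.

import jax
import jax.numpy as jnp
import flax.linen as nn

NUM_REPEATS = 2  # stand-in for config.num_pipeline_repeats

class TinyStage(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(4)(x)

# The lifted vmap creates one independent parameter set per repeat
# (split_rngs={"params": True}) and stacks them along a new axis 0.
RepeatedStage = nn.vmap(
    TinyStage,
    in_axes=0,
    variable_axes={"params": 0},
    split_rngs={"params": True},
)

x = jnp.ones((NUM_REPEATS, 3, 4))  # inputs broadcast to a leading repeat axis
variables = RepeatedStage().init(jax.random.PRNGKey(0), x)
print(variables["params"]["Dense_0"]["kernel"].shape)  # (2, 4, 4)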
@@ -814,63 +885,8 @@ def __call__(
     bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
 
     if self.is_initializing():
-      vmap_func = self.get_vmap_func_for_init()
-
-      if self.config.num_pipeline_repeats > 1:
-        # To shard the weights on initialization for the circular pipeline we create weights of
-        # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
-        # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
-        vmap_func = nn.vmap(
-            vmap_func,
-            in_axes=(0, segment_idx, position_idx, None, None),
-            variable_axes={
-                "params": 0,
-                "_overwrite_with_gradient": 0,
-                "non_trainable": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"params": True, "dropout": self.config.enable_dropout},
-            metadata_params={
-                nn.PARTITION_NAME: "circular_repeats",
-                "sub_weight_split_dims_mapping": (None,),
-                "is_initializing": True,
-                "x_times": self.config.num_pipeline_repeats,
-                "optimizer_dims_mapping": None,
-            },
-        )
-
-        example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
-        example_segmentation = (
-            jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
-            if example_segmentation is not None
-            else None
-        )
-        example_position = (
-            jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
-            if example_position is not None
-            else None
-        )
-      # We only need to run one set of stages to initialize the variables, instead of looping over all
-      # microbatches for the full total_iterations.
-      example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
-      stage_outputs = vmap_func(
-          self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
-      )
-      if self.config.scan_layers:
-        stage_outputs = stage_outputs[0]
-
-      # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stage's
-      # output, which has shape [pipeline_microbatch_size, sequence, embed].
-      if self.config.num_pipeline_repeats > 1:
-        stage_outputs = stage_outputs[0]  # Remove the extra dimension created for the circular vmap.
-      broadcasted_stage_outpus = jax.lax.broadcast(
-          stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
-      )
-
-      return jnp.reshape(
-          broadcasted_stage_outpus,
-          [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
-          out_sharding=self.output_sharding,
+      return self._run_initialization(
+          example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
       )
 
     logical_partition_spec = pipeline_utils.get_logical_spec_repeats_removed(logical_partition_spec)
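For reference, a toy-sized numeric sketch of the broadcast-then-reshape step at the end of _run_initialization, which tiles one stage's output of shape [pipeline_microbatch_size, sequence, embed] up to the full batch; all sizes below are assumed illustrative values.

import jax
import jax.numpy as jnp

pipeline_microbatch_size, seq_len, emb_dim = 2, 3, 4
micro_batch_size_to_train_on = 6  # assumed to be a multiple of the microbatch size

one_stage_output = jnp.ones((pipeline_microbatch_size, seq_len, emb_dim))
# jax.lax.broadcast prepends the given sizes as new leading dimensions.
factor = micro_batch_size_to_train_on // pipeline_microbatch_size
tiled = jax.lax.broadcast(one_stage_output, [factor])  # shape (3, 2, 3, 4)
full = jnp.reshape(tiled, (micro_batch_size_to_train_on, seq_len, emb_dim))
print(full.shape)  # (6, 3, 4)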
@@ -897,95 +913,39 @@ def run_iteration_scannable(model, loop_state):
         policy=self.get_pipeline_remat_policy(),
     )
 
-    def run_one_repeat_scannable(model, loop_state):
-      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
-      )
-
-      if model.config.scan_pipeline_iterations:
-        run_one_repeat_scanned_custom = pipeline_utils.create_scanned_function(
-            model=model,
-            run_iteration_scannable=run_iteration_scannable,
-            length=model.config.num_pipeline_microbatches,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
-            deterministic=deterministic,
-            model_mode=model_mode,
-            logical_partition_spec=logical_partition_spec,
-        )
-        loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids)
-      else:
-        for _ in range(model.config.num_pipeline_microbatches):
-          loop_state, _ = run_iteration_scannable(model, loop_state)
-      return loop_state, None
-
-    run_one_repeat_scannable = nn.remat(
-        run_one_repeat_scannable,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
+    run_one_repeat_scannable = pipeline_utils.create_run_scannable(
+        model=self,
+        run_iteration_scannable=run_iteration_scannable,
+        length=self.config.num_pipeline_microbatches,
+        deterministic=deterministic,
+        model_mode=model_mode,
+        logical_partition_spec=logical_partition_spec,
+        physical_partition_spec=physical_partition_spec,
+        positions=positions,
+        segment_ids=segment_ids,
     )
 
-    def run_bubbles_scannable(model, loop_state):
-      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
-      )
-
-      if model.config.scan_pipeline_iterations:
-        run_bubbles_scanned_custom = pipeline_utils.create_scanned_function(
-            model=model,
-            run_iteration_scannable=run_iteration_scannable,
-            length=bubble_iterations,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
-            deterministic=deterministic,
-            model_mode=model_mode,
-            logical_partition_spec=logical_partition_spec,
-        )
-        loop_state = run_bubbles_scanned_custom(loop_state, positions, segment_ids)
-      else:
-        for _ in range(model.config.num_pipeline_microbatches):
-          loop_state, _ = run_iteration_scannable(model, loop_state)
-      return loop_state, None
-
-    run_bubbles_scannable = nn.remat(
-        run_bubbles_scannable,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
+    run_bubbles_scannable = pipeline_utils.create_run_scannable(
+        model=self,
+        run_iteration_scannable=run_iteration_scannable,
+        length=bubble_iterations,
+        deterministic=deterministic,
+        model_mode=model_mode,
+        logical_partition_spec=logical_partition_spec,
+        physical_partition_spec=physical_partition_spec,
+        positions=positions,
+        segment_ids=segment_ids,
     )
 
     def run_all_iterations(model, loop_state):
       if self.config.scan_pipeline_repeats:
-        run_repeats_scanned = nn.scan(
-            run_one_repeat_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
+        run_repeats_scanned = pipeline_utils.create_run_repeats_scanned(
+            run_scannable=run_one_repeat_scannable,
             length=model.config.num_pipeline_repeats,
         )
 
-        run_bubbles_scanned = nn.scan(
-            run_bubbles_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
+        run_bubbles_scanned = pipeline_utils.create_run_repeats_scanned(
+            run_scannable=run_bubbles_scannable,
             length=1,
         )
         loop_state, _ = run_repeats_scanned(model, loop_state)
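For reference, a minimal sketch of the nn.scan pattern that the pipeline_utils helpers (create_run_scannable / create_run_repeats_scanned) presumably encapsulate: scanning a per-iteration body over a carried loop-state dict with shared weights. Iteration, RunAllIterations, and the loop-state fields below are hypothetical stand-ins mirroring the loop_state carry above.

import jax
import jax.numpy as jnp
import flax.linen as nn

class Iteration(nn.Module):
  @nn.compact
  def __call__(self, loop_state, _):
    # One pipeline step: update the carried activations and the iteration counter.
    loop_state["x"] = nn.Dense(4)(loop_state["x"])
    loop_state["loop_iteration"] = loop_state["loop_iteration"] + 1
    return loop_state, None

class RunAllIterations(nn.Module):
  length: int = 3

  @nn.compact
  def __call__(self, loop_state):
    scanned = nn.scan(
        Iteration,
        variable_broadcast="params",   # one shared weight set across all steps
        split_rngs={"params": False},  # params rng used once, not per step
        length=self.length,
    )
    loop_state, _ = scanned()(loop_state, None)
    return loop_state

init_state = {"x": jnp.ones((2, 4)), "loop_iteration": jnp.array(0)}
out, _ = RunAllIterations().init_with_output(jax.random.PRNGKey(0), init_state)
print(out["loop_iteration"])  # 3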