1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """ Pipeline layer wrapping a decoder layer(s). Supports circular pipelining """
15+ """Pipeline layer wrapping a decoder layer(s). Supports circular pipelining"""
1616
17- import functools
17+ # import functools
1818from typing import Any
19+ import functools
1920
2021import numpy as np
2122
@@ -225,6 +226,7 @@ def _init_bsw_from_weights(variables):
225226 "loop_iteration" : 0 ,
226227 "prev_outputs" : prev_outputs ,
227228 "bsw" : bsw ,
229+ "weights" : self .layers .variables ,
228230 }
229231 return init_loop_state
230232
@@ -455,6 +457,7 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx):
455457 "loop_iteration" : loop_iteration + 1 ,
456458 "prev_outputs" : new_prev_outputs ,
457459 "bsw" : loop_state ["bsw" ], # bsw is updated outside of this inner loop, only once per outer loop iteration
460+ "weights" : loop_state ["weights" ], # Pass weights through
458461 }
459462 return new_loop_state
460463
@@ -469,7 +472,9 @@ def permute_output_micro_per_stage_dim(self, output):
469472 output = output [:, permutation ]
470473 return output
471474
472- def get_current_stage_weights (self , pipeline_weights , bsw , loop_iteration , physical_partition_spec = None ):
475+ def get_current_stage_weights (
476+ self , pipeline_weights , bsw , loop_iteration , physical_partition_spec = None , is_initializing = None
477+ ):
473478 """
474479 Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g.
475480 {'mlp': 'wo': [stages, mlp, embed]}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc.
@@ -479,11 +484,11 @@ def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physi
479484 pipeline_weights = self ._remove_logically_partition (pipeline_weights )
480485 if self .config .num_pipeline_repeats > 1 :
481486 pipeline_weights = self .get_current_weights_from_bsw (
482- bsw , loop_iteration , physical_partition_spec = physical_partition_spec
487+ bsw , loop_iteration , physical_partition_spec = physical_partition_spec , is_initializing = is_initializing
483488 )
484489 return pipeline_weights
485490
486- def get_current_weights_from_bsw (self , bsw , loop_iteration , physical_partition_spec ):
491+ def get_current_weights_from_bsw (self , bsw , loop_iteration , physical_partition_spec , is_initializing = None ):
487492 """Collect and gather weights from given bsw (buffer sliding window)"""
488493 bsw_pps = jax .tree .map (self ._remove_fsdp_from_physical_partition_spec , physical_partition_spec )
489494 _ , repeat_ids = self .get_microbatch_and_repeat_ids (loop_iteration )
@@ -506,10 +511,13 @@ def select_weights_from_bsw(bsw, repeat_id):
506511
507512 weights = select_weights_from_bsw (bsw , repeat_ids )
508513
514+ if is_initializing is None :
515+ is_initializing = self .is_initializing ()
516+
509517 circular_metadata_params = {
510518 nn .PARTITION_NAME : "circular_repeats" ,
511519 "sub_weight_split_dims_mapping" : (None ,),
512- "is_initializing" : self . is_initializing () ,
520+ "is_initializing" : is_initializing ,
513521 "x_times" : self .config .num_pipeline_repeats ,
514522 "optimizer_dims_mapping" : None ,
515523 }
@@ -550,7 +558,7 @@ def find_fsdp(pspec):
550558
551559 return jax .tree .map (find_fsdp , physical_partition_spec )
552560
553- def from_all_variables_to_repeat_weights (self , loop_iteration , physical_partition_spec ):
561+ def from_all_variables_to_repeat_weights (self , weights , loop_iteration , physical_partition_spec ):
554562 """Generate one single repeat weight from all variables."""
555563 _ , repeat_ids = self .get_microbatch_and_repeat_ids (loop_iteration )
556564
@@ -559,24 +567,24 @@ def gather_weights_for_stages_in(w, spec):
559567 w , repeat_ids = repeat_ids , repeat_dim_in_weights = 0 , stages_dim_in_weights = 1 , physical_partition_spec = spec
560568 )
561569
562- weights = self ._remove_logically_partition (self . layers . variables )
570+ weights = self ._remove_logically_partition (weights )
563571 if physical_partition_spec is None :
564- repeat_weights = jax .tree .map (gather_weights_for_stages_in , weights )
572+ weights = jax .tree .map (gather_weights_for_stages_in , weights )
565573 else :
566- repeat_weights = jax .tree .map (gather_weights_for_stages_in , weights , physical_partition_spec )
574+ weights = jax .tree .map (gather_weights_for_stages_in , weights , physical_partition_spec )
567575 circular_metadata_params = {
568576 nn .PARTITION_NAME : "circular_repeats" ,
569577 "sub_weight_split_dims_mapping" : (None ,),
570578 "is_initializing" : self .is_initializing (),
571579 "x_times" : self .config .num_pipeline_repeats ,
572580 "optimizer_dims_mapping" : None ,
573581 }
574- repeat_weights = meta .remove_axis (repeat_weights , 0 , circular_metadata_params )
582+ repeat_weights = meta .remove_axis (weights , 0 , circular_metadata_params )
575583 return repeat_weights
576584
577- def from_all_variables_to_bsw (self , loop_iteration , physical_partition_spec ):
585+ def from_all_variables_to_bsw (self , weights , loop_iteration , physical_partition_spec ):
578586 """All gather one branch of bsw using shardmap."""
579- repeat_weights = self .from_all_variables_to_repeat_weights (loop_iteration , physical_partition_spec )
587+ repeat_weights = self .from_all_variables_to_repeat_weights (weights , loop_iteration , physical_partition_spec )
580588 bsw_pps = self ._generate_bsw_pps_from_pps (physical_partition_spec )
581589 repeat_weights_pps = jax .tree .map (lambda p : P (* p [1 :]), physical_partition_spec )
582590 fsdp_idx = self .get_fsdp_index_pytree (physical_partition_spec )
@@ -597,10 +605,10 @@ def _all_gather_invariant(x, i):
597605
598606 return _all_gather_inner (repeat_weights , fsdp_idx )
599607
def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iteration):
  """All gather all bsw (buffer sliding window) branches over the fsdp mesh axis using shardmap.

  Args:
    weights: pytree of pipeline weights to gather from. Passed explicitly (rather
      than read from ``self.layers.variables``) so the weights can be threaded
      through the loop state by the caller.
    physical_partition_spec: pytree of physical PartitionSpecs matching ``weights``.
    loop_iteration: current pipeline loop iteration; branches are built for this
      iteration and the next one (the sliding window holds two repeats).

  Returns:
    A ``(bsw_0, bsw_1)`` tuple tagged with checkpoint name "bsw" so a remat
    policy can decide whether to save or recompute the gathered weights.
  """
  bsw_0 = self.from_all_variables_to_bsw(weights, loop_iteration, physical_partition_spec)
  bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec)
  return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
605613
606614 def get_vmap_func_for_init (self ):
@@ -666,20 +674,22 @@ def func_to_vmap(
666674 def run_one_iteration (
667675 self ,
668676 loop_state ,
669- pipeline_weights ,
670677 positions ,
671678 segment_ids ,
672679 deterministic ,
673680 model_mode ,
674681 decoder_layer_instance ,
675682 logical_partition_spec ,
683+ vmap_func = None ,
684+ is_initializing = None ,
676685 ):
677686 """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel,
678687 and update the loop state."""
679688 state_io = loop_state ["state_io" ]
680689 shift = loop_state ["shift" ]
681690 circ_storage = loop_state ["circ_storage" ]
682691 loop_iteration = loop_state ["loop_iteration" ]
692+ pipeline_weights = loop_state ["weights" ]
683693
684694 microbatch_ids , _ = self .get_microbatch_and_repeat_ids (loop_iteration )
685695
@@ -693,49 +703,15 @@ def run_one_iteration(
693703 stages_positions = self .vmap_gather (positions , microbatch_ids , 0 ) if positions is not None else None
694704 stages_segment_ids = self .vmap_gather (segment_ids , microbatch_ids , 0 ) if segment_ids is not None else None
695705
696- vmap_func = self .get_main_vmap_func_for_iterations ()
697-
698- if self .config .num_pipeline_repeats > 1 :
699- _ , repeat_ids = self .get_microbatch_and_repeat_ids (loop_iteration )
700-
701- def prepare_vars_for_main_vmap (weights , physical_partition_spec = None ):
702-
703- circular_metadata_params = {
704- nn .PARTITION_NAME : "circular_repeats" ,
705- "sub_weight_split_dims_mapping" : (None ,),
706- "is_initializing" : self .is_initializing (),
707- "x_times" : self .config .num_pipeline_repeats ,
708- "optimizer_dims_mapping" : None ,
709- }
710- weights = meta .remove_axis (
711- weights , 0 , circular_metadata_params
712- ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one
713- # circular entry per stage.
714- weights = self ._remove_logically_partition (weights )
715-
716- def gather_weights_for_stages_in (w , spec = None ):
717- return self .vmap_parallel_gather (
718- w , repeat_ids = repeat_ids , repeat_dim_in_weights = 0 , stages_dim_in_weights = 1 , physical_partition_spec = spec
719- )
720-
721- if physical_partition_spec is None :
722- weights = jax .tree .map (gather_weights_for_stages_in , weights )
723- else :
724- weights = jax .tree .map (gather_weights_for_stages_in , weights , physical_partition_spec )
725- return weights
726-
727- prepare_vars_for_main_vmap_partial = functools .partial (
728- prepare_vars_for_main_vmap , physical_partition_spec = physical_partition_spec
729- )
730- vmap_func = nn .map_variables (
731- vmap_func ,
732- mapped_collections = ["params" , "_overwrite_with_gradient" , "non_trainable" , "summaries" , "intermediates" ],
733- mutable = True ,
734- trans_in_fn = prepare_vars_for_main_vmap_partial ,
735- )
706+ if vmap_func is None :
707+ vmap_func = self .get_main_vmap_func_for_iterations ()
736708
737709 stage_weights = self .get_current_stage_weights (
738- pipeline_weights , loop_state ["bsw" ], loop_iteration , physical_partition_spec = physical_partition_spec
710+ pipeline_weights ,
711+ loop_state ["bsw" ],
712+ loop_iteration ,
713+ physical_partition_spec = physical_partition_spec ,
714+ is_initializing = is_initializing ,
739715 )
740716
741717 stages_output = vmap_func (
@@ -978,7 +954,6 @@ def run_iteration_scannable(model, loop_state):
978954 return (
979955 model .run_one_iteration (
980956 loop_state ,
981- model .layers .variables ,
982957 positions ,
983958 segment_ids ,
984959 deterministic ,
@@ -997,7 +972,9 @@ def run_iteration_scannable(model, loop_state):
997972 )
998973
999974 def run_one_repeat_scannable (model , loop_state ):
1000- loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (physical_partition_spec , loop_state ["loop_iteration" ])
975+ loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (
976+ loop_state ["weights" ], physical_partition_spec , loop_state ["loop_iteration" ]
977+ )
1001978
1002979 if model .config .scan_pipeline_iterations :
1003980 run_one_repeat_scanned = nn .scan (
@@ -1014,7 +991,85 @@ def run_one_repeat_scannable(model, loop_state):
1014991 split_rngs = {"random" : True },
1015992 length = model .config .num_pipeline_microbatches ,
1016993 )
1017- loop_state , _ = run_one_repeat_scanned (model , loop_state )
994+
# NOTE: `functools.partial(jax.custom_vjp)` with no extra arguments is a no-op
# wrapper; apply the decorator directly.
@jax.custom_vjp
def run_one_repeat_scanned_custom(loop_state, positions, segment_ids):
  """Run one repeat's scanned pipeline iterations under a custom VJP.

  ``positions`` and ``segment_ids`` appear as explicit primal arguments only so
  the custom VJP can return (zero) cotangents for them; the scanned forward
  itself closes over them via ``run_one_repeat_scanned`` / ``model``.

  Returns the final loop state after scanning all microbatch iterations.
  """
  final_state, _ = run_one_repeat_scanned(model, loop_state)
  return final_state
999+
def run_one_repeat_scanned_custom_fwd(loop_state, positions, segment_ids):
  """Forward rule for the custom VJP: run the scan, save only the initial state.

  Saving the *initial* ``loop_state`` (instead of per-iteration activations)
  lets the backward rule recompute the forward pass iteration by iteration —
  gradient checkpointing at repeat granularity, trading compute for memory.
  """
  final_state, _ = run_one_repeat_scanned(model, loop_state)
  # Residuals: the initial loop_state plus the inputs needed to replay the
  # forward scan. `model` is not a residual — the bwd rule closes over it.
  return final_state, (
      loop_state,
      positions,
      segment_ids,
  )
1008+
def run_one_repeat_scanned_custom_bwd(residuals, g_final_state):
  """Backward rule: recompute the forward scan, then backprop it in reverse.

  Args:
    residuals: ``(init_loop_state, positions, segment_ids)`` saved by the fwd rule.
    g_final_state: cotangent pytree matching the final loop state.

  Returns:
    A 3-tuple of cotangents for ``(loop_state, positions, segment_ids)``;
    positions/segment_ids get ``None`` (treated as non-differentiable here).
  """
  init_loop_state, positions, segment_ids = residuals

  # Re-run forward pass to get saved states (checkpointing)
  def scan_body_fwd(carry, _):
    new_state = model.run_one_iteration(
        carry,
        positions,
        segment_ids,
        deterministic,
        model_mode,
        model.layers,
        logical_partition_spec=logical_partition_spec,
    )
    # Return lightweight state for saving (exclude bsw/weights).
    # NOTE(review): this assumes bsw and weights are constant across the inner
    # scan (bsw is only updated once per outer repeat) — confirm against caller.
    saved = {k: v for k, v in carry.items() if k not in ["bsw", "weights"]}
    return new_state, saved

  _, saved_states = jax.lax.scan(
      scan_body_fwd,
      init_loop_state,
      None,
      length=model.config.num_pipeline_microbatches,
  )

  # Backward scan to accumulate gradients
  def scan_body_bwd(carry, saved_slice):
    # carry is the cotangent of the *output* state of the current step.
    d_next_state = carry

    # Reconstruct current loop_state (the input to this step) from the saved
    # lightweight slice plus the (inner-scan-invariant) bsw/weights.
    curr_loop_state = {
        **saved_slice,
        "bsw": init_loop_state["bsw"],
        "weights": init_loop_state["weights"],
    }

    # Define function to differentiate w.r.t. loop_state only; positions,
    # segment_ids, and model are closed over (no cotangents computed for them).
    def step_fn(s):
      out = model.run_one_iteration(
          s,
          positions,
          segment_ids,
          deterministic,
          model_mode,
          model.layers,
          logical_partition_spec=logical_partition_spec,
      )
      return out

    _, vjp_fun = jax.vjp(step_fn, curr_loop_state)

    # Backprop d_next_state through one iteration. The resulting cotangent
    # includes entries for "bsw" and "weights", so the carry structure must
    # match g_final_state's — presumably the final state keeps those keys too.
    (d_curr_state,) = vjp_fun(d_next_state)

    return d_curr_state, None

  # Run backward scan: consume the saved per-step input states newest-first.
  d_init_state, _ = jax.lax.scan(scan_body_bwd, g_final_state, saved_states, reverse=True)

  return (d_init_state, None, None)
1069+
# Register the forward/backward rules, then run the repeat through the
# custom-VJP wrapper so backprop uses the rematerializing rules above.
run_one_repeat_scanned_custom.defvjp(run_one_repeat_scanned_custom_fwd, run_one_repeat_scanned_custom_bwd)

loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids)
10181073 else :
10191074 for _ in range (model .config .num_pipeline_microbatches ):
10201075 loop_state , _ = run_iteration_scannable (model , loop_state )
@@ -1056,7 +1111,9 @@ def run_all_iterations(model, loop_state):
10561111 length = bubble_iterations ,
10571112 )
10581113 loop_state , _ = run_repeats_scanned (model , loop_state )
1059- loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (physical_partition_spec , loop_state ["loop_iteration" ])
1114+ loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (
1115+ loop_state ["weights" ], physical_partition_spec , loop_state ["loop_iteration" ]
1116+ )
10601117 loop_state , _ = run_bubbles_scanned (model , loop_state )
10611118 else :
10621119 for _ in range (model .config .num_pipeline_repeats ): # remat and scan outer loop
@@ -1068,14 +1125,11 @@ def run_all_iterations(model, loop_state):
10681125 # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized.
10691126 # if self.config.scan_pipeline_iterations:
10701127 variable_carry = []
1071- variable_broadcast = [
1072- "params" ,
1073- "_overwrite_with_gradient" ,
1074- ] # All loop iterations need the weights for the full pipeline.
1075- if self .is_mutable_collection ("non_trainable" ):
1076- variable_carry .append ("non_trainable" )
1077- else :
1078- variable_broadcast .append ("non_trainable" )
1128+ variable_broadcast = [] # All loop iterations need the weights for the full pipeline.
1129+ # if self.is_mutable_collection("non_trainable"):
1130+ # variable_carry.append("non_trainable")
1131+ # else:
1132+ # variable_broadcast.append("non_trainable")
10791133
10801134 loop_state = run_all_iterations (self , loop_state )
10811135
0 commit comments