@@ -550,24 +550,20 @@ def find_fsdp(pspec):
550550
551551 return jax .tree .map (find_fsdp , physical_partition_spec )
552552
553- def bsw_all_gather_over_fsdp (self , weights , bsw , physical_partition_spec , loop_iteration ):
554- """All gather bsw over fsdp mesh axis using shardmap."""
555- bsw_pps = self ._generate_bsw_pps_from_pps (physical_partition_spec )
556- repeat_weights_pps = jax .tree .map (lambda p : P (* p [1 :]), physical_partition_spec )
557- fsdp_idx = self .get_fsdp_index_pytree (physical_partition_spec )
558-
559- _ , repeat_ids = self .get_microbatch_and_repeat_ids (loop_iteration + 1 )
553+ def from_all_variables_to_repeat_weights (self , loop_iteration , physical_partition_spec ):
554+ """Generate one single repeat weight from all variables."""
555+ _ , repeat_ids = self .get_microbatch_and_repeat_ids (loop_iteration )
560556
561557 def gather_weights_for_stages_in (w , spec ):
562558 return self .vmap_parallel_gather (
563559 w , repeat_ids = repeat_ids , repeat_dim_in_weights = 0 , stages_dim_in_weights = 1 , physical_partition_spec = spec
564560 )
565561
562+ weights = self ._remove_logically_partition (self .layers .variables )
566563 if physical_partition_spec is None :
567564 repeat_weights = jax .tree .map (gather_weights_for_stages_in , weights )
568565 else :
569566 repeat_weights = jax .tree .map (gather_weights_for_stages_in , weights , physical_partition_spec )
570-
571567 circular_metadata_params = {
572568 nn .PARTITION_NAME : "circular_repeats" ,
573569 "sub_weight_split_dims_mapping" : (None ,),
@@ -576,25 +572,36 @@ def gather_weights_for_stages_in(w, spec):
576572 "optimizer_dims_mapping" : None ,
577573 }
578574 repeat_weights = meta .remove_axis (repeat_weights , 0 , circular_metadata_params )
575+ return repeat_weights
576+
577+ def from_all_variables_to_bsw (self , loop_iteration , physical_partition_spec ):
578+ """All gather one branch of bsw using shardmap."""
579+ repeat_weights = self .from_all_variables_to_repeat_weights (loop_iteration , physical_partition_spec )
580+ bsw_pps = self ._generate_bsw_pps_from_pps (physical_partition_spec )
581+ repeat_weights_pps = jax .tree .map (lambda p : P (* p [1 :]), physical_partition_spec )
582+ fsdp_idx = self .get_fsdp_index_pytree (physical_partition_spec )
579583
580584 @jax .shard_map (
581585 mesh = self .mesh ,
582- in_specs = (repeat_weights_pps , ( bsw_pps , bsw_pps ), None ),
583- out_specs = ( bsw_pps , bsw_pps ) ,
586+ in_specs = (repeat_weights_pps , None ),
587+ out_specs = bsw_pps ,
584588 check_vma = True ,
585589 )
586- def _all_gather_inner (sharded_weights , cur_bsw , fsdp_idx ):
590+ def _all_gather_inner (sharded_weights , fsdp_idx ):
587591 def _all_gather_invariant (x , i ):
588592 if i >= 0 :
589593 return all_gather_invariant (x , axis_name = "fsdp" , axis = i - 1 , tiled = True )
590594 return x
591595
592- new_variables = jax .tree .map (_all_gather_invariant , sharded_weights , fsdp_idx )
593- new_variables = jax .ad_checkpoint .checkpoint_name (new_variables , "bsw_gathered_weights" )
596+ return jax .tree .map (_all_gather_invariant , sharded_weights , fsdp_idx )
594597
595- return jax . ad_checkpoint . checkpoint_name (( cur_bsw [ 1 ], new_variables ), "bsw" )
598+ return _all_gather_inner ( repeat_weights , fsdp_idx )
596599
597- return _all_gather_inner (repeat_weights , bsw , fsdp_idx )
600+ def bsw_all_gather_over_fsdp (self , physical_partition_spec , loop_iteration ):
601+ """All gather all bsw over fsdp mesh axis using shardmap."""
602+ bsw_0 = self .from_all_variables_to_bsw (loop_iteration , physical_partition_spec )
603+ bsw_1 = self .from_all_variables_to_bsw (loop_iteration + 1 , physical_partition_spec )
604+ return jax .ad_checkpoint .checkpoint_name ((bsw_0 , bsw_1 ), "bsw" )
598605
599606 def get_vmap_func_for_init (self ):
600607 """This vmap func is used to initialize the weights only on init."""
@@ -985,15 +992,12 @@ def run_iteration_scannable(model, loop_state):
985992 if self .config .set_remat_policy_on_pipeline_iterations :
986993 run_iteration_scannable = nn .remat (
987994 run_iteration_scannable ,
988- prevent_cse = True , # not self.config.scan_pipeline_iterations,
995+ prevent_cse = not self .config .scan_pipeline_iterations ,
989996 policy = self .get_pipeline_remat_policy (),
990997 )
991998
992999 def run_one_repeat_scannable (model , loop_state ):
993- weights = model ._remove_logically_partition (model .layers .variables ) # pylint: disable=protected-access
994- loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (
995- weights , loop_state ["bsw" ], physical_partition_spec , loop_state ["loop_iteration" ]
996- )
1000+ loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (physical_partition_spec , loop_state ["loop_iteration" ])
9971001
9981002 if model .config .scan_pipeline_iterations :
9991003 run_one_repeat_scanned = nn .scan (
@@ -1018,7 +1022,7 @@ def run_one_repeat_scannable(model, loop_state):
10181022
10191023 run_one_repeat_scannable = nn .remat (
10201024 run_one_repeat_scannable ,
1021- prevent_cse = True ,
1025+ prevent_cse = not self . config . scan_pipeline_iterations ,
10221026 policy = self .get_pipeline_remat_policy (),
10231027 )
10241028
@@ -1052,7 +1056,7 @@ def run_all_iterations(model, loop_state):
10521056 length = bubble_iterations ,
10531057 )
10541058 loop_state , _ = run_repeats_scanned (model , loop_state )
1055- loop_state ["bsw" ] = ( loop_state [ "bsw" ][ 1 ], jax . tree . map ( jnp . zeros_like , loop_state ["bsw" ][ 1 ]) )
1059+ loop_state ["bsw" ] = model . bsw_all_gather_over_fsdp ( physical_partition_spec , loop_state ["loop_iteration" ] )
10561060 loop_state , _ = run_bubbles_scanned (model , loop_state )
10571061 else :
10581062 for _ in range (model .config .num_pipeline_repeats ): # remat and scan outer loop
0 commit comments