@@ -328,6 +328,31 @@ def _run_weight_initialization(
328328 out_sharding = self .output_sharding ,
329329 )
330330
331+ @staticmethod
332+ def _remove_fsdp_from_physical_partition_spec (pps ):
333+ """Removes 'fsdp' and 'fsdp_transpose' from physical partition spec."""
334+ if isinstance (pps , P ):
335+ new_spec = []
336+ # Iterate through each axis in the original PartitionSpec.
337+ for axis in pps :
338+ if axis is None :
339+ new_spec .append (None )
340+ elif isinstance (axis , str ):
341+ # If the axis is 'fsdp', replace it with None to signify replication.
342+ if axis not in ("fsdp" , "fsdp_transpose" ):
343+ new_spec .append (axis )
344+ else :
345+ new_spec .append (None )
346+ elif isinstance (axis , (list , tuple )):
347+ # If the axis is a collection, filter out 'fsdp'.
348+ new_axis = [a for a in axis if a not in ("fsdp" , "fsdp_transpose" )]
349+ new_spec .append (tuple (new_axis ))
350+ else :
351+ raise ValueError (f"Unsupported_axis_type: { type (axis )} " )
352+ # Return a new sharding object with the modified spec.
353+ return P (* new_spec )
354+ return pps
355+
331356
332357class Pipeline (PipelineBase ):
333358 """Original Pipeline implementation."""
@@ -754,31 +779,6 @@ def _remove_logically_partition_leaf(v):
754779
755780 return jax .tree .map (_remove_logically_partition_leaf , weights , is_leaf = lambda v : isinstance (v , LogicallyPartitioned ))
756781
757- @staticmethod
758- def _remove_fsdp_from_physical_partition_spec (pps ):
759- """Removes 'fsdp' and 'fsdp_transpose' from physical partition spec."""
760- if isinstance (pps , P ):
761- new_spec = []
762- # Iterate through each axis in the original PartitionSpec.
763- for axis in pps :
764- if axis is None :
765- new_spec .append (None )
766- elif isinstance (axis , str ):
767- # If the axis is 'fsdp', replace it with None to signify replication.
768- if axis not in ("fsdp" , "fsdp_transpose" ):
769- new_spec .append (axis )
770- else :
771- new_spec .append (None )
772- elif isinstance (axis , (list , tuple )):
773- # If the axis is a collection, filter out 'fsdp'.
774- new_axis = [a for a in axis if a not in ("fsdp" , "fsdp_transpose" )]
775- new_spec .append (tuple (new_axis ))
776- else :
777- raise ValueError (f"Unsupported_axis_type: { type (axis )} " )
778- # Return a new sharding object with the modified spec.
779- return P (* new_spec )
780- return pps
781-
782782 def all_gather_over_fsdp (self , variables , logical_partition_spec ):
783783 """Gathers FSDP partitioned variables to reconstruct them fully."""
784784 physical_partition_spec = logical_to_mesh (
@@ -1107,7 +1107,7 @@ def get_current_stage_weights(
11071107
11081108 def get_current_weights_from_bsw (self , bsw , loop_iteration , physical_partition_spec , is_initializing = None ):
11091109 """Retrieves current weights out of the sliding buffer window (bsw)."""
1110- bsw_pps = jax .tree .map (pipeline_utils . remove_fsdp_from_physical_partition_spec , physical_partition_spec )
1110+ bsw_pps = jax .tree .map (self . _remove_fsdp_from_physical_partition_spec , physical_partition_spec )
11111111 _ , repeat_ids = self .get_microbatch_and_repeat_ids (loop_iteration )
11121112 target_repeat_id = repeat_ids [0 ]
11131113
@@ -1159,12 +1159,12 @@ def from_repeat_weights_to_bsw(
11591159 axes_to_gather = ("fsdp" , "fsdp_transpose" , "expert" ), # three major FSDP-like axes
11601160 ):
11611161 """Generates the buffer sliding window (bsw) from the gathered repeat weights."""
1162- bsw_pps = pipeline_utils .generate_bsw_pps_from_pps (physical_partition_spec )
1162+ bsw_pps = pipeline_utils .derive_stage_weight_partition_specs (physical_partition_spec )
11631163 repeat_weights_pps = jax .tree .map (lambda p : P (* p [1 :]), physical_partition_spec )
11641164
11651165 # Dynamically gather the index pytrees for all specified axes
11661166 axis_indices_dict = {
1167- axis : pipeline_utils .get_fsdp_index_pytree (physical_partition_spec , axis ) for axis in axes_to_gather
1167+ axis : pipeline_utils .get_mesh_axis_dim_indices (physical_partition_spec , axis ) for axis in axes_to_gather
11681168 }
11691169
11701170 axis_names = list (axis_indices_dict .keys ())
@@ -1299,7 +1299,7 @@ def __call__(
12991299 example_inputs , example_segmentation , example_position , segment_idx , position_idx , deterministic , model_mode
13001300 )
13011301
1302- logical_partition_spec = pipeline_utils .get_logical_spec_repeats_removed (logical_partition_spec )
1302+ logical_partition_spec = pipeline_utils .strip_pipeline_repeat_logical_axis (logical_partition_spec )
13031303
13041304 def run_iteration_scannable (model , loop_state ):
13051305 return (
@@ -1318,7 +1318,7 @@ def run_iteration_scannable(model, loop_state):
13181318
13191319 # base scannable function used twice for real and bubble runs
13201320 base_scannable = functools .partial (
1321- pipeline_utils .create_run_scannable ,
1321+ pipeline_utils .create_rematerialized_pipeline_stage ,
13221322 model = self ,
13231323 run_iteration_scannable = run_iteration_scannable ,
13241324 deterministic = deterministic ,
@@ -1334,10 +1334,10 @@ def run_iteration_scannable(model, loop_state):
13341334
13351335 def run_all_iterations (model , loop_state ):
13361336 if self .config .scan_pipeline_repeats :
1337- run_repeats_scanned = pipeline_utils .create_run_repeats_scanned (
1338- run_scannable = run_one_repeat_scannable , length = model .config .num_pipeline_repeats
1337+ run_repeats_scanned = pipeline_utils .create_flax_pipeline_scan (
1338+ pipeline_stage_fn = run_one_repeat_scannable , length = model .config .num_pipeline_repeats
13391339 )
1340- run_bubbles_scanned = pipeline_utils .create_run_repeats_scanned ( run_scannable = run_bubbles_scannable , length = 1 )
1340+ run_bubbles_scanned = pipeline_utils .create_flax_pipeline_scan ( pipeline_stage_fn = run_bubbles_scannable , length = 1 )
13411341 loop_state , _ = run_repeats_scanned (model , loop_state )
13421342 loop_state , _ = run_bubbles_scanned (model , loop_state )
13431343 else :
0 commit comments