Skip to content

Commit d70484b

Browse files
committed
refactor pr
1 parent 3ecc35b commit d70484b

3 files changed

Lines changed: 204 additions & 125 deletions

File tree

src/maxtext/layers/pipeline.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,31 @@ def _run_weight_initialization(
328328
out_sharding=self.output_sharding,
329329
)
330330

331+
@staticmethod
def _remove_fsdp_from_physical_partition_spec(pps):
  """Removes 'fsdp' and 'fsdp_transpose' from a physical partition spec.

  Args:
    pps: A leaf of a partition-spec pytree. Only `jax.sharding.PartitionSpec`
      instances are rewritten; any other value is returned unchanged.

  Returns:
    A new `PartitionSpec` in which every 'fsdp'/'fsdp_transpose' mesh axis is
    removed: a bare string axis becomes None (replicated), and a list/tuple
    axis is filtered down to its remaining components (possibly an empty
    tuple). Non-PartitionSpec inputs are passed through untouched.

  Raises:
    ValueError: If an axis entry is neither None, a string, nor a list/tuple.
  """
  if not isinstance(pps, P):
    return pps
  fsdp_axes = ("fsdp", "fsdp_transpose")
  new_spec = []
  # Iterate through each axis in the original PartitionSpec.
  for axis in pps:
    if axis is None:
      new_spec.append(None)
    elif isinstance(axis, str):
      # A bare FSDP-like axis is replaced with None to signify replication.
      new_spec.append(axis if axis not in fsdp_axes else None)
    elif isinstance(axis, (list, tuple)):
      # A composite axis keeps only its non-FSDP components.
      new_spec.append(tuple(a for a in axis if a not in fsdp_axes))
    else:
      raise ValueError(f"Unsupported axis type: {type(axis)}")
  # Return a new sharding object with the modified spec.
  return P(*new_spec)
355+
331356

332357
class Pipeline(PipelineBase):
333358
"""Original Pipeline implementation."""
@@ -754,31 +779,6 @@ def _remove_logically_partition_leaf(v):
754779

755780
return jax.tree.map(_remove_logically_partition_leaf, weights, is_leaf=lambda v: isinstance(v, LogicallyPartitioned))
756781

757-
@staticmethod
758-
def _remove_fsdp_from_physical_partition_spec(pps):
759-
"""Removes 'fsdp' and 'fsdp_transpose' from physical partition spec."""
760-
if isinstance(pps, P):
761-
new_spec = []
762-
# Iterate through each axis in the original PartitionSpec.
763-
for axis in pps:
764-
if axis is None:
765-
new_spec.append(None)
766-
elif isinstance(axis, str):
767-
# If the axis is 'fsdp', replace it with None to signify replication.
768-
if axis not in ("fsdp", "fsdp_transpose"):
769-
new_spec.append(axis)
770-
else:
771-
new_spec.append(None)
772-
elif isinstance(axis, (list, tuple)):
773-
# If the axis is a collection, filter out 'fsdp'.
774-
new_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")]
775-
new_spec.append(tuple(new_axis))
776-
else:
777-
raise ValueError(f"Unsupported_axis_type: {type(axis)}")
778-
# Return a new sharding object with the modified spec.
779-
return P(*new_spec)
780-
return pps
781-
782782
def all_gather_over_fsdp(self, variables, logical_partition_spec):
783783
"""Gathers FSDP partitioned variables to reconstruct them fully."""
784784
physical_partition_spec = logical_to_mesh(
@@ -1107,7 +1107,7 @@ def get_current_stage_weights(
11071107

11081108
def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec, is_initializing=None):
11091109
"""Retrieves current weights out of the sliding buffer window (bsw)."""
1110-
bsw_pps = jax.tree.map(pipeline_utils.remove_fsdp_from_physical_partition_spec, physical_partition_spec)
1110+
bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
11111111
_, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
11121112
target_repeat_id = repeat_ids[0]
11131113

@@ -1159,12 +1159,12 @@ def from_repeat_weights_to_bsw(
11591159
axes_to_gather=("fsdp", "fsdp_transpose", "expert"), # three major FSDP-like axes
11601160
):
11611161
"""Generates the buffer sliding window (bsw) from the gathered repeat weights."""
1162-
bsw_pps = pipeline_utils.generate_bsw_pps_from_pps(physical_partition_spec)
1162+
bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec)
11631163
repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
11641164

11651165
# Dynamically gather the index pytrees for all specified axes
11661166
axis_indices_dict = {
1167-
axis: pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, axis) for axis in axes_to_gather
1167+
axis: pipeline_utils.get_mesh_axis_dim_indices(physical_partition_spec, axis) for axis in axes_to_gather
11681168
}
11691169

11701170
axis_names = list(axis_indices_dict.keys())
@@ -1299,7 +1299,7 @@ def __call__(
12991299
example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
13001300
)
13011301

1302-
logical_partition_spec = pipeline_utils.get_logical_spec_repeats_removed(logical_partition_spec)
1302+
logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec)
13031303

13041304
def run_iteration_scannable(model, loop_state):
13051305
return (
@@ -1318,7 +1318,7 @@ def run_iteration_scannable(model, loop_state):
13181318

13191319
# base scannable function used twice for real and bubble runs
13201320
base_scannable = functools.partial(
1321-
pipeline_utils.create_run_scannable,
1321+
pipeline_utils.create_rematerialized_pipeline_stage,
13221322
model=self,
13231323
run_iteration_scannable=run_iteration_scannable,
13241324
deterministic=deterministic,
@@ -1334,10 +1334,10 @@ def run_iteration_scannable(model, loop_state):
13341334

13351335
def run_all_iterations(model, loop_state):
13361336
if self.config.scan_pipeline_repeats:
1337-
run_repeats_scanned = pipeline_utils.create_run_repeats_scanned(
1338-
run_scannable=run_one_repeat_scannable, length=model.config.num_pipeline_repeats
1337+
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
1338+
pipeline_stage_fn=run_one_repeat_scannable, length=model.config.num_pipeline_repeats
13391339
)
1340-
run_bubbles_scanned = pipeline_utils.create_run_repeats_scanned(run_scannable=run_bubbles_scannable, length=1)
1340+
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
13411341
loop_state, _ = run_repeats_scanned(model, loop_state)
13421342
loop_state, _ = run_bubbles_scanned(model, loop_state)
13431343
else:

0 commit comments

Comments
 (0)