Commit 04ffb00

clean version fsdp+pp bug free
1 parent 51ddc35 commit 04ffb00

1 file changed

Lines changed: 2 additions & 64 deletions

File tree

src/MaxText/layers/pipeline.py

@@ -275,31 +275,6 @@ def select_state_or_input(first_stage_in, shift):
     stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical)
     return stages_in

-  def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False):
-    """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at
-    the specified dimension."""
-    # placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED
-    # if physical_partition_spec is None:
-    #   dims_mapping = [placeholder] * x.ndim
-    # else:
-    #   physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
-    #   dims_mapping = list(physical_partition_spec)
-    #   # If not a stage weight, we handle the repeat dimension offset
-    #   if not is_stage_weight:
-    #     dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:]  # inflat one dimension for num_repeats
-    # dims_mapping[dim] = "stage"
-    # dims_mapping = tuple(dims_mapping)
-    # # We add reduced rule only when pspec is given for a stage weight
-    # if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT:
-    #   batch_mesh_axis = ["data", "fsdp"]
-    #   reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1]
-    #   pspec = P(*dims_mapping, reduced=set(reduced_mark))
-    # else:
-    #   pspec = P(*dims_mapping)
-    # sharding = jax.sharding.NamedSharding(self.mesh, pspec)
-    # return self._maybe_shard_with_name(x, sharding)
-    return x
-
   def get_microbatch_and_repeat_ids(self, loop_iteration):
     """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
     non-circular"""
@@ -309,14 +284,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids

-  def get_microbatch_and_repeat_ids_for_bsw(self, loop_iteration):
-    """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
-    non-circular"""
-    raw_processed = loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages)
-    repeat_ids = raw_processed // self.config.num_pipeline_microbatches
-    microbatch_ids = jnp.maximum(raw_processed, 0) % self.config.num_pipeline_microbatches
-    return microbatch_ids, repeat_ids
-
   def vmap_parallel_gather(
       self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights
   ):
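Note: the per-stage id arithmetic is easiest to follow with concrete numbers. The toy values below are made up for illustration and reuse the formulas from the removed get_microbatch_and_repeat_ids_for_bsw; for early loop iterations, the repeat ids of stages that have not started yet come out negative.

import jax.numpy as jnp

num_stages = 4
num_pipeline_microbatches = 8
forwarding_delay = 1
loop_iteration = 10

# Each stage lags the previous one by forwarding_delay loop iterations.
raw_processed = loop_iteration - forwarding_delay * jnp.arange(num_stages)  # [10  9  8  7]
repeat_ids = raw_processed // num_pipeline_microbatches                     # [1 1 1 0]
microbatch_ids = jnp.maximum(raw_processed, 0) % num_pipeline_microbatches  # [2 1 0 7]
print(microbatch_ids, repeat_ids)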
@@ -339,18 +306,9 @@ def _gather_one(x, repeat_id):
       return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights)

     gathered_weights_stage_dim = 0
-    repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None)
-    # num_repeats x num_stages x *param_dim
-    weights = self.shard_dim_by_stages(
-        weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False
-    )
     stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(
         weights, repeat_ids
     )
-    # num_stages x *param_dim
-    stage_weights = self.shard_dim_by_stages(
-        stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True
-    )
     return stage_weights

   def vmap_gather(self, xs, ids, ids_dim):
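Note: the vmap gather that remains in vmap_parallel_gather after this change can be exercised on its own. The sketch below uses toy shapes and hardcodes repeat_dim_in_weights=0 and stages_dim_in_weights=1 (assumptions for illustration, not MaxText code): for each stage, it slices out the repeat that stage is currently running.

import jax
import jax.numpy as jnp

num_repeats, num_stages, d = 3, 4, 2
# num_repeats x num_stages x *param_dims
weights = jnp.arange(num_repeats * num_stages * d * d, dtype=jnp.float32).reshape(
    num_repeats, num_stages, d, d
)
repeat_ids = jnp.array([1, 1, 0, 0])  # one repeat id per stage

def _gather_one(x, repeat_id):
  # x: num_repeats x *param_dims, the per-stage slice handed over by vmap below.
  return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, 0), 0)

# vmap over the stage axis of `weights` and over `repeat_ids`; stage ends up at axis 0.
stage_weights = jax.vmap(_gather_one, in_axes=(1, 0), out_axes=0)(weights, repeat_ids)
print(stage_weights.shape)  # (num_stages, d, d) == (4, 2, 2)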
@@ -374,9 +332,7 @@ def _gather_one(x, i):
       replicated_sharding = NamedSharding(self.mesh, P())
       return x.at[idx].get(out_sharding=replicated_sharding)

-    ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
-    outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
-    return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None)
+    return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)

   def get_new_loop_state(self, output, loop_state):
     """
@@ -533,17 +489,6 @@ def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_s
     _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
     target_repeat_id = repeat_ids[0]

-    # path = ("params", "mlp", "wi_0", "kernel")
-    # path = ("params", "weights")
-
-    # jax.debug.print(
-    #     "Iteration: {iter} | Global Target Repeat ID: {target} | Repeat_ids: {rids} | "
-    #     "BSW[0] per-stage means: {bsw0} | BSW[1] per-stage means: {bsw1}",
-    #     iter=loop_iteration, target=target_repeat_id, rids=repeat_ids,
-    #     bsw0=maxtext_utils.get_nested_value(bsw[0], path).mean(axis=(1, 2)),
-    #     bsw1=maxtext_utils.get_nested_value(bsw[1], path).mean(axis=(1, 2)),
-    # )
-
     @jax.shard_map(
         mesh=self.mesh,
         in_specs=((bsw_pps, bsw_pps), P("stage")),
@@ -556,14 +501,7 @@ def select_weights_from_bsw(bsw, repeat_id):
           bsw[0],
           bsw[1],
       )
-      # jax.debug.print(
-      #     "Iteration: {iter} | "
-      #     "Selected weights mean for Stage {s} with repeat id {i}: {m}",
-      #     iter=loop_iteration,
-      #     s=jax.lax.axis_index("stage"),
-      #     m=maxtext_utils.get_nested_value(weights, path).mean(),
-      #     i=repeat_id[0],
-      # )
+
       return weights

     weights = select_weights_from_bsw(bsw, repeat_ids)
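Note: the shard_map'd select_weights_from_bsw picks, per stage, one of the two buffered copies of the stage weights (bsw[0] or bsw[1]) based on that stage's repeat id. The sketch below mimics that selection with jnp.where on toy data; the parity rule, shapes, and function name are illustrative assumptions, not the logic in pipeline.py.

import jax.numpy as jnp

num_stages, d = 4, 2
bsw = (
    jnp.zeros((num_stages, d, d)),  # buffer 0, one weight block per stage
    jnp.ones((num_stages, d, d)),   # buffer 1
)
repeat_ids = jnp.array([2, 1, 1, 0])  # per-stage repeat ids at some loop iteration

def select_weights(bsw, repeat_ids):
  # Assumed rule: odd repeat ids read from buffer 1, even ones from buffer 0.
  use_buffer_1 = (repeat_ids % 2).astype(bool).reshape(-1, 1, 1)
  return jnp.where(use_buffer_1, bsw[1], bsw[0])

weights = select_weights(bsw, repeat_ids)
print(weights[:, 0, 0])  # [0. 1. 1. 0.]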
