@@ -275,31 +275,6 @@ def select_state_or_input(first_stage_in, shift):
     stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical)
     return stages_in
 
-  def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False):
-    """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at
-    the specified dimension."""
-    # placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED
-    # if physical_partition_spec is None:
-    # dims_mapping = [placeholder] * x.ndim
-    # else:
-    # physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
-    # dims_mapping = list(physical_partition_spec)
-    # # If not a stage weight, we handle the repeat dimension offset
-    # if not is_stage_weight:
-    # dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:]  # inflate one dimension for num_repeats
-    # dims_mapping[dim] = "stage"
-    # dims_mapping = tuple(dims_mapping)
-    # # We add the reduced rule only when a pspec is given for a stage weight
-    # if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT:
-    # batch_mesh_axis = ["data", "fsdp"]
-    # reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1]
-    # pspec = P(*dims_mapping, reduced=set(reduced_mark))
-    # else:
-    # pspec = P(*dims_mapping)
-    # sharding = jax.sharding.NamedSharding(self.mesh, pspec)
-    # return self._maybe_shard_with_name(x, sharding)
-    return x
-
   def get_microbatch_and_repeat_ids(self, loop_iteration):
     """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
     non-circular"""
@@ -309,14 +284,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids
 
-  def get_microbatch_and_repeat_ids_for_bsw(self, loop_iteration):
-    """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
-    non-circular"""
-    raw_processed = loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages)
-    repeat_ids = raw_processed // self.config.num_pipeline_microbatches
-    microbatch_ids = jnp.maximum(raw_processed, 0) % self.config.num_pipeline_microbatches
-    return microbatch_ids, repeat_ids
-
   def vmap_parallel_gather(
       self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights
   ):
@@ -339,18 +306,9 @@ def _gather_one(x, repeat_id):
       return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights)
 
     gathered_weights_stage_dim = 0
-    repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None)
-    # num_repeats x num_stages x *param_dim
-    weights = self.shard_dim_by_stages(
-        weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False
-    )
     stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(
         weights, repeat_ids
     )
-    # num_stages x *param_dim
-    stage_weights = self.shard_dim_by_stages(
-        stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True
-    )
     return stage_weights
 
   def vmap_gather(self, xs, ids, ids_dim):
@@ -374,9 +332,7 @@ def _gather_one(x, i):
       replicated_sharding = NamedSharding(self.mesh, P())
       return x.at[idx].get(out_sharding=replicated_sharding)
 
-    ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
-    outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
-    return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None)
+    return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
 
   def get_new_loop_state(self, output, loop_state):
     """
@@ -533,17 +489,6 @@ def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_s
     _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
     target_repeat_id = repeat_ids[0]
 
-    # path = ("params", "mlp", "wi_0", "kernel")
-    # path = ("params", "weights")
-
-    # jax.debug.print(
-    #     "Iteration: {iter} | Global Target Repeat ID: {target} | Repeat_ids: {rids} | "
-    #     "BSW[0] per-stage means: {bsw0} | BSW[1] per-stage means: {bsw1}",
-    #     iter=loop_iteration, target=target_repeat_id, rids=repeat_ids,
-    #     bsw0=maxtext_utils.get_nested_value(bsw[0], path).mean(axis=(1, 2)),
-    #     bsw1=maxtext_utils.get_nested_value(bsw[1], path).mean(axis=(1, 2)),
-    # )
-
     @jax.shard_map(
         mesh=self.mesh,
         in_specs=((bsw_pps, bsw_pps), P("stage")),
@@ -556,14 +501,7 @@ def select_weights_from_bsw(bsw, repeat_id):
           bsw[0],
           bsw[1],
       )
-      # jax.debug.print(
-      #     "Iteration: {iter} | "
-      #     "Selected weights mean for Stage {s} with repeat id {i}: {m}",
-      #     iter=loop_iteration,
-      #     s=jax.lax.axis_index("stage"),
-      #     m=maxtext_utils.get_nested_value(weights, path).mean(),
-      #     i=repeat_id[0],
-      # )
+
       return weights
 
     weights = select_weights_from_bsw(bsw, repeat_ids)
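
For context on the id arithmetic in the removed get_microbatch_and_repeat_ids_for_bsw above, here is a minimal standalone sketch of the same per-stage computation. The function name and the concrete values of num_stages, forwarding_delay, and num_pipeline_microbatches are hypothetical, chosen only to show how each stage lags the first and wraps around the microbatch count.

import jax.numpy as jnp

# Hypothetical pipeline configuration, for illustration only.
num_stages = 4
forwarding_delay = 1
num_pipeline_microbatches = 8

def microbatch_and_repeat_ids(loop_iteration):
  # Stage s starts forwarding_delay * s iterations after stage 0.
  raw_processed = loop_iteration - forwarding_delay * jnp.arange(num_stages)
  # Which pass over the repeated stages each stage is currently on.
  repeat_ids = raw_processed // num_pipeline_microbatches
  # Which microbatch each stage holds; clamp the warm-up bubble at 0.
  microbatch_ids = jnp.maximum(raw_processed, 0) % num_pipeline_microbatches
  return microbatch_ids, repeat_ids

# At iteration 10: raw_processed = [10 9 8 7],
# so microbatch_ids = [2 1 0 7] and repeat_ids = [1 1 1 0].
print(microbatch_and_repeat_ids(10))

The retained get_microbatch_and_repeat_ids applies the same floor division to derive repeat_ids from microbatches_processed.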