 from flax.linen.spmd import LogicallyPartitioned
 
 from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
+# from MaxText import maxtext_utils
 from MaxText.sharding import (
     maybe_shard_with_logical,
     maybe_shard_with_name,
@@ -204,12 +205,17 @@ def init_states(self, inputs):
 
     def _init_bsw_from_weights(variables):
       """Buffer space for two copies of weights."""
-      return jax.tree.map(lambda x: jnp.zeros_like(x[:2]), variables)
+      # Take the index-0 slice, assuming num_layers_per_pipeline_stage=1.
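+      # The two zero buffers act as a sliding window over repeats: slot 0 holds the previous
+      # repeat's weights and slot 1 the freshly gathered ones (see bsw_all_gather_over_fsdp).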
+      return (
+          jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables),
+          jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables),
+      )
 
     if self.is_initializing():
       bsw = None
     else:
-      bsw = _init_bsw_from_weights(self.layers.variables)
+      variables = self._remove_logically_partition(self.layers.variables)
+      bsw = _init_bsw_from_weights(variables)
 
     init_loop_state = {
         "state_io": state_io,
@@ -269,6 +275,31 @@ def select_state_or_input(first_stage_in, shift):
     stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical)
     return stages_in
 
+  def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False):
+    """Shards x with the provided partition spec, adding the "stage" mesh axis to the existing
+    sharding at the specified dimension."""
+    # placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED
+    # if physical_partition_spec is None:
+    #   dims_mapping = [placeholder] * x.ndim
+    # else:
+    #   physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
+    #   dims_mapping = list(physical_partition_spec)
+    #   # If not a stage weight, handle the repeat dimension offset.
+    #   if not is_stage_weight:
+    #     dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:]  # inflate one dimension for num_repeats
+    # dims_mapping[dim] = "stage"
+    # dims_mapping = tuple(dims_mapping)
+    # # Add the reduced rule only when a pspec is given for a stage weight.
+    # if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT:
+    #   batch_mesh_axis = ["data", "fsdp"]
+    #   reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1]
+    #   pspec = P(*dims_mapping, reduced=set(reduced_mark))
+    # else:
+    #   pspec = P(*dims_mapping)
+    # sharding = jax.sharding.NamedSharding(self.mesh, pspec)
+    # return self._maybe_shard_with_name(x, sharding)
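+    # The sharding logic above is commented out, so this method is currently a no-op passthrough.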
+    return x
+
   def get_microbatch_and_repeat_ids(self, loop_iteration):
     """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
     non-circular"""
@@ -278,6 +309,14 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids
 
+  def get_microbatch_and_repeat_ids_for_bsw(self, loop_iteration):
+    """Gets the microbatch_ids and repeat_ids for the bsw buffers for all stages on this loop_iteration.
+    Works for both circular and non-circular pipelines."""
+    raw_processed = loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages)
+    repeat_ids = raw_processed // self.config.num_pipeline_microbatches
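+    # During warmup, raw_processed is negative for later stages, so their repeat_ids go negative;
+    # the microbatch_ids below are clamped to 0 instead.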
+    microbatch_ids = jnp.maximum(raw_processed, 0) % self.config.num_pipeline_microbatches
+    return microbatch_ids, repeat_ids
+
   def vmap_parallel_gather(
       self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights
   ):
@@ -300,9 +339,18 @@ def _gather_one(x, repeat_id):
       return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights)
 
     gathered_weights_stage_dim = 0
+    repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None)
+    # num_repeats x num_stages x *param_dim
+    weights = self.shard_dim_by_stages(
+        weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False
+    )
     stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(
         weights, repeat_ids
     )
+    # num_stages x *param_dim
+    stage_weights = self.shard_dim_by_stages(
+        stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True
+    )
     return stage_weights
 
   def vmap_gather(self, xs, ids, ids_dim):
@@ -326,8 +374,9 @@ def _gather_one(x, i):
       replicated_sharding = NamedSharding(self.mesh, P())
       return x.at[idx].get(out_sharding=replicated_sharding)
 
+    ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
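+    # Both the gather indices and outputs get per-stage sharding annotations (currently a
+    # passthrough while shard_dim_by_stages is disabled).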
     outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
-    return outs
+    return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None)
 
   def get_new_loop_state(self, output, loop_state):
     """
@@ -471,20 +520,53 @@ def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physi
     For non-circular pipelines, this simply returns all weights - every weight is used in every iteration. However,
     for circular pipelines each stage grabs only the weights corresponding to the current repeat.
     """
+    pipeline_weights = self._remove_logically_partition(pipeline_weights)
     if self.config.num_pipeline_repeats > 1:
-      return self.get_current_weights_from_bsw(bsw, loop_iteration, physical_partition_spec=physical_partition_spec)
-    else:
-      return pipeline_weights
+      pipeline_weights = self.get_current_weights_from_bsw(
+          bsw, loop_iteration, physical_partition_spec=physical_partition_spec
+      )
+    return pipeline_weights
 
-  def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec=None):
+  def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec):
     """Collect and gather weights from the given bsw (buffer sliding window)."""
+    bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
+    _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
+    target_repeat_id = repeat_ids[0]
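+    # target_repeat_id is the repeat currently entering stage 0; stages whose repeat_id matches it
+    # should read the freshly gathered buffer.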
 
-    def _get_bsw_idx(loop_iteration):
-      _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
-      bsw_ids = (repeat_ids == repeat_ids[0]).astype(
-          jnp.int32
-      )  # For early repeats this might return true when it should be false
-      return bsw_ids
+    # path = ("params", "mlp", "wi_0", "kernel")
+    # path = ("params", "weights")
+
+    # jax.debug.print(
+    #     "Iteration: {iter} | Global Target Repeat ID: {target} | Repeat_ids: {rids} | "
+    #     "BSW[0] per-stage means: {bsw0} | BSW[1] per-stage means: {bsw1}",
+    #     iter=loop_iteration, target=target_repeat_id, rids=repeat_ids,
+    #     bsw0=maxtext_utils.get_nested_value(bsw[0], path).mean(axis=(1, 2)),
+    #     bsw1=maxtext_utils.get_nested_value(bsw[1], path).mean(axis=(1, 2)),
+    # )
+
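+    # Per-stage select: under shard_map each stage sees only its own slice of repeat_ids and picks
+    # buffer 1 when its repeat matches the target, buffer 0 otherwise.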
+    @jax.shard_map(
+        mesh=self.mesh,
+        in_specs=((bsw_pps, bsw_pps), P("stage")),
+        out_specs=bsw_pps,
+        check_vma=True,
+    )
+    def select_weights_from_bsw(bsw, repeat_id):
+      weights = jax.tree.map(
+          lambda x, y: jax.lax.select(repeat_id[0] == target_repeat_id, y, x),
+          bsw[0],
+          bsw[1],
+      )
+      # jax.debug.print(
+      #     "Iteration: {iter} | "
+      #     "Selected weights mean for Stage {s} with repeat id {i}: {m}",
+      #     iter=loop_iteration,
+      #     s=jax.lax.axis_index("stage"),
+      #     m=maxtext_utils.get_nested_value(weights, path).mean(),
+      #     i=repeat_id[0],
+      # )
+      return weights
+
+    weights = select_weights_from_bsw(bsw, repeat_ids)
 
     circular_metadata_params = {
         nn.PARTITION_NAME: "circular_repeats",
@@ -494,24 +576,10 @@ def _get_bsw_idx(loop_iteration):
494576 "optimizer_dims_mapping" : None ,
495577 }
496578 weights = meta .remove_axis (
497- bsw , 0 , circular_metadata_params
579+ weights , 0 , circular_metadata_params
498580 ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular
499581 # entry per stage.
500- weights = self ._remove_logically_partition (weights )
501582
-    def gather_weights_for_stages_in(w, spec=None):
-      return self.vmap_parallel_gather(
-          w,
-          repeat_ids=_get_bsw_idx(loop_iteration),
-          repeat_dim_in_weights=0,
-          stages_dim_in_weights=1,
-          physical_partition_spec=spec,
-      )
-
-    if physical_partition_spec is None:
-      weights = jax.tree.map(gather_weights_for_stages_in, weights)
-    else:
-      weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
     return weights
 
   @staticmethod
@@ -544,40 +612,50 @@ def find_fsdp(pspec):
 
     return jax.tree.map(find_fsdp, physical_partition_spec)
 
-  def bsw_all_gather_over_fsdp(self, bsw, physical_partition_spec, loop_iteration):
+  def bsw_all_gather_over_fsdp(self, weights, bsw, physical_partition_spec, loop_iteration):
     """All-gather bsw over the fsdp mesh axis using shard_map."""
-    pps_no_fsdp = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
+    bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
+    repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
     fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec)
 
     _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration + 1)
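+    # Look one iteration ahead so the next repeat's weights are staged before they are needed.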
 
+    def gather_weights_for_stages_in(w, spec):
+      return self.vmap_parallel_gather(
+          w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
+      )
+
+    if physical_partition_spec is None:
+      repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights)
+    else:
+      repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
+
+    circular_metadata_params = {
+        nn.PARTITION_NAME: "circular_repeats",
+        "sub_weight_split_dims_mapping": (None,),
+        "is_initializing": self.is_initializing(),
+        "x_times": self.config.num_pipeline_repeats,
+        "optimizer_dims_mapping": None,
+    }
+    repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params)
+
     @jax.shard_map(
         mesh=self.mesh,
-        in_specs=(physical_partition_spec, pps_no_fsdp, None, None),
-        out_specs=pps_no_fsdp,
+        in_specs=(repeat_weights_pps, (bsw_pps, bsw_pps), None),
+        out_specs=(bsw_pps, bsw_pps),
         check_vma=True,
     )
-    def _all_gather_inner(variables, cur_bsw, repeat_idx, fsdp_idx):
-      new_variables = jax.tree.map(
-          lambda x: jax.lax.dynamic_slice_in_dim(x, repeat_idx, 1),
-          variables,
-      )
-
+    def _all_gather_inner(sharded_weights, cur_bsw, fsdp_idx):
       def _all_gather_invariant(x, i):
         if i >= 0:
-          return all_gather_invariant(x, axis_name="fsdp", axis=i, tiled=True)
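+          # fsdp_idx was computed on specs that still carry the leading repeat dim, which
+          # remove_axis has since stripped, hence the shift by one.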
+          return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
         return x
 
-      new_variables = jax.tree.map(_all_gather_invariant, new_variables, fsdp_idx)
-
-      def shift_and_insert(bsw_leaf, new_leaf):
-        updated_bsw = bsw_leaf.at[0].set(bsw_leaf[1])
-        updated_bsw = updated_bsw.at[1].set(jnp.squeeze(new_leaf, axis=0))
-        return updated_bsw
+      new_variables = jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
 
-      return jax.tree.map(shift_and_insert, cur_bsw, new_variables)
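+      # Slide the window: the staged buffer becomes current and the freshly gathered weights
+      # fill the next slot.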
+      return (cur_bsw[1], new_variables)
 
-    return _all_gather_inner(self.layers.variables, bsw, repeat_ids[0], fsdp_idx)
+    return _all_gather_inner(repeat_weights, bsw, fsdp_idx)
 
   def get_vmap_func_for_init(self):
     """This vmap func is used to initialize the weights only on init."""
@@ -648,7 +726,7 @@ def run_one_iteration(
       deterministic,
       model_mode,
       decoder_layer_instance,
-      logical_partition_spec=None,
+      logical_partition_spec,
   ):
     """Run one loop iteration - get weights and inputs for each stage, run the stages in parallel,
     and update the loop state."""
@@ -811,6 +889,13 @@ def _remove_fsdp_from_physical_partition_spec(pps):
       return P(*new_spec)
     return pps
 
+  def _generate_bsw_pps_from_pps(self, physical_partition_spec):
+    """Create the bsw physical partition spec from the weight physical partition spec."""
+    return jax.tree.map(
+        lambda pps: P(*self._remove_fsdp_from_physical_partition_spec(pps)[1:]),
+        physical_partition_spec,
+    )
+
   @nn.compact
   def __call__(
       self,
@@ -966,8 +1051,9 @@ def run_iteration_scannable(model, loop_state):
       )
 
     def run_one_repeat_scannable(model, loop_state):
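+      # Unbox LogicallyPartitioned wrappers so raw arrays are passed into the shard_map gather.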
+      weights = model._remove_logically_partition(model.layers.variables)  # pylint: disable=protected-access
       loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
+          weights, loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
       )
 
       if model.config.scan_pipeline_iterations:
@@ -997,65 +1083,6 @@ def run_one_repeat_scannable(model, loop_state):
         policy=self.get_pipeline_remat_policy(),
     )
 
-    def run_real_repeats(model, loop_state):
-      if self.config.scan_pipeline_repeats:
-        run_repeats_scanned = nn.scan(
-            run_one_repeat_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            variable_broadcast=variable_broadcast,
-            variable_carry=variable_carry,
-            split_rngs={"random": True},
-            length=model.config.num_pipeline_repeats,
-        )
-        loop_state, _ = run_repeats_scanned(model, loop_state)
-      else:
-        for _ in range(model.config.num_pipeline_repeats):  # remat and scan outer loop
-          loop_state, _ = run_one_repeat_scannable(model, loop_state)
-      return loop_state
-
-    run_real_repeats = nn.remat(
-        run_real_repeats,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
-    )
-
-    def run_bubble_iterations_scannable(model, loop_state):
-      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
-      )
-
-      if model.config.scan_pipeline_iterations:
-        run_one_repeat_scanned = nn.scan(
-            run_iteration_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            variable_broadcast=variable_broadcast,
-            variable_carry=variable_carry,
-            # Dropout/aqt keys will be split for each iteration.
-            split_rngs={"random": True},
-            length=bubble_iterations,
-        )
-        loop_state, _ = run_one_repeat_scanned(model, loop_state)
-      else:
-        for _ in range(model.config.num_pipeline_microbatches):
-          loop_state, _ = run_iteration_scannable(model, loop_state)
-      return loop_state, None
-
-    run_bubble_iterations_scannable = nn.remat(
-        run_bubble_iterations_scannable,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
-    )
-
     def run_all_iterations(model, loop_state):
       if self.config.scan_pipeline_repeats:
@@ -1073,7 +1100,7 @@ def run_all_iterations(model, loop_state):
         )
 
         run_bubbles_scanned = nn.scan(
-            run_bubble_iterations_scannable,
+            run_iteration_scannable,
             variable_axes={
                 "summaries": 0,
                 "aux_loss": 0,
@@ -1083,9 +1110,10 @@ def run_all_iterations(model, loop_state):
             variable_broadcast=variable_broadcast,
             variable_carry=variable_carry,
             split_rngs={"random": True},
-            length=model.config.num_pipeline_repeats,
+            length=bubble_iterations,
         )
         loop_state, _ = run_repeats_scanned(model, loop_state)
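+        # Hand off to the bubble drain: promote the staged buffer and zero the next slot, since
+        # no further repeats will be gathered.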
1116+ loop_state ["bsw" ] = (loop_state ["bsw" ][1 ], jax .tree .map (jnp .zeros_like , loop_state ["bsw" ][1 ]))
10891117 loop_state , _ = run_bubbles_scanned (model , loop_state )
10901118 else :
10911119 for _ in range (model .config .num_pipeline_repeats ): # remat and scan outer loop