
Commit 51ddc35

working all gather insertion
1 parent 062a066 commit 51ddc35

3 files changed

Lines changed: 144 additions & 116 deletions

File tree

src/MaxText/layers/pipeline.py

Lines changed: 138 additions & 110 deletions
@@ -30,6 +30,7 @@
 from flax.linen.spmd import LogicallyPartitioned
 
 from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
+# from MaxText import maxtext_utils
 from MaxText.sharding import (
     maybe_shard_with_logical,
     maybe_shard_with_name,
@@ -204,12 +205,17 @@ def init_states(self, inputs):
 
     def _init_bsw_from_weights(variables):
       """Buffer space for two copies of weights."""
-      return jax.tree.map(lambda x: jnp.zeros_like(x[:2]), variables)
+      # take the idx-0 slice, assuming num_layers_per_pipeline_stage=1
+      return (
+          jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables),
+          jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables),
+      )
 
     if self.is_initializing():
       bsw = None
     else:
-      bsw = _init_bsw_from_weights(self.layers.variables)
+      variables = self._remove_logically_partition(self.layers.variables)
+      bsw = _init_bsw_from_weights(variables)
 
     init_loop_state = {
         "state_io": state_io,
@@ -269,6 +275,31 @@ def select_state_or_input(first_stage_in, shift):
     stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical)
     return stages_in
 
+  def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False):
+    """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at
+    the specified dimension."""
+    # placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED
+    # if physical_partition_spec is None:
+    #   dims_mapping = [placeholder] * x.ndim
+    # else:
+    #   physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
+    #   dims_mapping = list(physical_partition_spec)
+    # # If not a stage weight, we handle the repeat dimension offset
+    # if not is_stage_weight:
+    #   dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:]  # inflate one dimension for num_repeats
+    # dims_mapping[dim] = "stage"
+    # dims_mapping = tuple(dims_mapping)
+    # # We add the reduced rule only when a pspec is given for a stage weight
+    # if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT:
+    #   batch_mesh_axis = ["data", "fsdp"]
+    #   reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1]
+    #   pspec = P(*dims_mapping, reduced=set(reduced_mark))
+    # else:
+    #   pspec = P(*dims_mapping)
+    # sharding = jax.sharding.NamedSharding(self.mesh, pspec)
+    # return self._maybe_shard_with_name(x, sharding)
+    return x
+
   def get_microbatch_and_repeat_ids(self, loop_iteration):
     """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
     non-circular"""
@@ -278,6 +309,14 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids
 
+  def get_microbatch_and_repeat_ids_for_bsw(self, loop_iteration):
+    """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
+    non-circular."""
+    raw_processed = loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages)
+    repeat_ids = raw_processed // self.config.num_pipeline_microbatches
+    microbatch_ids = jnp.maximum(raw_processed, 0) % self.config.num_pipeline_microbatches
+    return microbatch_ids, repeat_ids
+
   def vmap_parallel_gather(
       self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights
   ):
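
A worked example of the bsw id computation added above, with illustrative values (4 stages, 6 microbatches per repeat, forwarding_delay=1):

import jax.numpy as jnp

num_stages, num_microbatches, forwarding_delay = 4, 6, 1
loop_iteration = 7

raw_processed = loop_iteration - forwarding_delay * jnp.arange(num_stages)  # [7 6 5 4]
repeat_ids = raw_processed // num_microbatches                              # [1 1 0 0]
microbatch_ids = jnp.maximum(raw_processed, 0) % num_microbatches           # [1 0 5 4]
# Later stages lag behind: stages 2-3 are still on repeat 0 while stages 0-1
# have moved to repeat 1. During warmup, negative raw values clamp to 0 for
# the microbatch id, while floor division leaves their repeat id negative.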
@@ -300,9 +339,18 @@ def _gather_one(x, repeat_id):
       return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights)
 
     gathered_weights_stage_dim = 0
+    repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None)
+    # num_repeats x num_stages x *param_dim
+    weights = self.shard_dim_by_stages(
+        weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False
+    )
     stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(
         weights, repeat_ids
     )
+    # num_stages x *param_dim
+    stage_weights = self.shard_dim_by_stages(
+        stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True
+    )
     return stage_weights
 
   def vmap_gather(self, xs, ids, ids_dim):
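
For context, the core vmapped gather that the sharding hints above wrap, as a standalone sketch. Shapes are illustrative assumptions: weights stacked as num_repeats x num_stages x ..., each stage i picking the slice for its current repeat_ids[i]:

import jax
import jax.numpy as jnp

num_repeats, num_stages, d = 3, 4, 2
weights = jnp.arange(num_repeats * num_stages * d, dtype=jnp.float32).reshape(num_repeats, num_stages, d)
repeat_ids = jnp.array([1, 1, 0, 0])

def gather_one(x, repeat_id):
  # x: [num_repeats, d] for one stage; select that stage's current repeat.
  return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, axis=0), axis=0)

# in_axes=(1, 0): map over the stage axis of weights and over repeat_ids.
stage_weights = jax.vmap(gather_one, in_axes=(1, 0), out_axes=0)(weights, repeat_ids)
print(stage_weights.shape)  # (4, 2): num_stages x *param_dim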
@@ -326,8 +374,9 @@ def _gather_one(x, i):
       replicated_sharding = NamedSharding(self.mesh, P())
       return x.at[idx].get(out_sharding=replicated_sharding)
 
+    ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
     outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
-    return outs
+    return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None)
 
   def get_new_loop_state(self, output, loop_state):
     """
@@ -471,20 +520,53 @@ def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physi
     For non-circular pipelines, this simply returns all weights - every weight is used in every iteration. However,
     for circular pipelines each stage grabs only the weights corresponding to the current repeat.
     """
+    pipeline_weights = self._remove_logically_partition(pipeline_weights)
     if self.config.num_pipeline_repeats > 1:
-      return self.get_current_weights_from_bsw(bsw, loop_iteration, physical_partition_spec=physical_partition_spec)
-    else:
-      return pipeline_weights
+      pipeline_weights = self.get_current_weights_from_bsw(
+          bsw, loop_iteration, physical_partition_spec=physical_partition_spec
+      )
+    return pipeline_weights
 
-  def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec=None):
+  def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec):
     """Collect and gather weights from given bsw (buffer sliding window)."""
+    bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
+    _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
+    target_repeat_id = repeat_ids[0]
 
-    def _get_bsw_idx(loop_iteration):
-      _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
-      bsw_ids = (repeat_ids == repeat_ids[0]).astype(
-          jnp.int32
-      )  # For early repeats this might return true when it should be false
-      return bsw_ids
+    # path = ("params", "mlp", "wi_0", "kernel")
+    # path = ("params", "weights")
+
+    # jax.debug.print(
+    #     "Iteration: {iter} | Global Target Repeat ID: {target} | Repeat_ids: {rids} | "
+    #     "BSW[0] per-stage means: {bsw0} | BSW[1] per-stage means: {bsw1}",
+    #     iter=loop_iteration, target=target_repeat_id, rids=repeat_ids,
+    #     bsw0=maxtext_utils.get_nested_value(bsw[0], path).mean(axis=(1, 2)),
+    #     bsw1=maxtext_utils.get_nested_value(bsw[1], path).mean(axis=(1, 2)),
+    # )
+
+    @jax.shard_map(
+        mesh=self.mesh,
+        in_specs=((bsw_pps, bsw_pps), P("stage")),
+        out_specs=(bsw_pps),
+        check_vma=True,
+    )
+    def select_weights_from_bsw(bsw, repeat_id):
+      weights = jax.tree.map(
+          lambda x, y: jax.lax.select(repeat_id[0] == target_repeat_id, y, x),
+          bsw[0],
+          bsw[1],
+      )
+      # jax.debug.print(
+      #     "Iteration: {iter} | "
+      #     "Selected weights mean for Stage {s} with repeat id {i}: {m}",
+      #     iter=loop_iteration,
+      #     s=jax.lax.axis_index("stage"),
+      #     m=maxtext_utils.get_nested_value(weights, path).mean(),
+      #     i=repeat_id[0],
+      # )
+      return weights
+
+    weights = select_weights_from_bsw(bsw, repeat_ids)
 
     circular_metadata_params = {
         nn.PARTITION_NAME: "circular_repeats",
@@ -494,24 +576,10 @@ def _get_bsw_idx(loop_iteration):
         "optimizer_dims_mapping": None,
     }
     weights = meta.remove_axis(
-        bsw, 0, circular_metadata_params
+        weights, 0, circular_metadata_params
     )  # Remove the circular metadata axis; this axis will be removed when passed to the main vmap, only one circular
     # entry per stage.
-    weights = self._remove_logically_partition(weights)
 
-    def gather_weights_for_stages_in(w, spec=None):
-      return self.vmap_parallel_gather(
-          w,
-          repeat_ids=_get_bsw_idx(loop_iteration),
-          repeat_dim_in_weights=0,
-          stages_dim_in_weights=1,
-          physical_partition_spec=spec,
-      )
-
-    if physical_partition_spec is None:
-      weights = jax.tree.map(gather_weights_for_stages_in, weights)
-    else:
-      weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
     return weights
 
   @staticmethod
@@ -544,40 +612,50 @@ def find_fsdp(pspec):
 
     return jax.tree.map(find_fsdp, physical_partition_spec)
 
-  def bsw_all_gather_over_fsdp(self, bsw, physical_partition_spec, loop_iteration):
+  def bsw_all_gather_over_fsdp(self, weights, bsw, physical_partition_spec, loop_iteration):
     """All gather bsw over the fsdp mesh axis using shard_map."""
-    pps_no_fsdp = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
+    bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
+    repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
     fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec)
 
     _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration + 1)
 
+    def gather_weights_for_stages_in(w, spec):
+      return self.vmap_parallel_gather(
+          w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
+      )
+
+    if physical_partition_spec is None:
+      repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights)
+    else:
+      repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
+
+    circular_metadata_params = {
+        nn.PARTITION_NAME: "circular_repeats",
+        "sub_weight_split_dims_mapping": (None,),
+        "is_initializing": self.is_initializing(),
+        "x_times": self.config.num_pipeline_repeats,
+        "optimizer_dims_mapping": None,
+    }
+    repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params)
+
     @jax.shard_map(
         mesh=self.mesh,
-        in_specs=(physical_partition_spec, pps_no_fsdp, None, None),
-        out_specs=pps_no_fsdp,
+        in_specs=(repeat_weights_pps, (bsw_pps, bsw_pps), None),
+        out_specs=(bsw_pps, bsw_pps),
         check_vma=True,
     )
-    def _all_gather_inner(variables, cur_bsw, repeat_idx, fsdp_idx):
-      new_variables = jax.tree.map(
-          lambda x: jax.lax.dynamic_slice_in_dim(x, repeat_idx, 1),
-          variables,
-      )
-
+    def _all_gather_inner(sharded_weights, cur_bsw, fsdp_idx):
       def _all_gather_invariant(x, i):
         if i >= 0:
-          return all_gather_invariant(x, axis_name="fsdp", axis=i, tiled=True)
+          return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
         return x
 
-      new_variables = jax.tree.map(_all_gather_invariant, new_variables, fsdp_idx)
-
-      def shift_and_insert(bsw_leaf, new_leaf):
-        updated_bsw = bsw_leaf.at[0].set(bsw_leaf[1])
-        updated_bsw = updated_bsw.at[1].set(jnp.squeeze(new_leaf, axis=0))
-        return updated_bsw
+      new_variables = jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
 
-      return jax.tree.map(shift_and_insert, cur_bsw, new_variables)
+      return (cur_bsw[1], new_variables)
 
-    return _all_gather_inner(self.layers.variables, bsw, repeat_ids[0], fsdp_idx)
+    return _all_gather_inner(repeat_weights, bsw, fsdp_idx)
 
   def get_vmap_func_for_init(self):
     """This vmap func is used to initialize the weights only on init."""
@@ -648,7 +726,7 @@ def run_one_iteration(
       deterministic,
       model_mode,
       decoder_layer_instance,
-      logical_partition_spec=None,
+      logical_partition_spec,
   ):
     """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel,
     and update the loop state."""
@@ -811,6 +889,13 @@ def _remove_fsdp_from_physical_partition_spec(pps):
       return P(*new_spec)
     return pps
 
+  def _generate_bsw_pps_from_pps(self, physical_partition_spec):
+    """Create the bsw physical partition spec from the weight physical partition spec."""
+    return jax.tree.map(
+        lambda pps: P(*self._remove_fsdp_from_physical_partition_spec(pps)[1:]),
+        physical_partition_spec,
+    )
+
   @nn.compact
   def __call__(
       self,
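
A small illustration of the helper added above, under the simplifying assumption that each spec names "fsdp" at most once at the top level (the example spec is made up): drop the "fsdp" sharding, then drop the leading circular-repeat dimension.

from jax.sharding import PartitionSpec as P

def remove_fsdp(pps):
  return P(*(None if axis == "fsdp" else axis for axis in pps))

def bsw_pps_from_pps(pps):
  # Strip "fsdp", then [1:] drops the leading repeat dimension.
  return P(*tuple(remove_fsdp(pps))[1:])

print(bsw_pps_from_pps(P(None, "stage", "fsdp", "tensor")))
# PartitionSpec('stage', None, 'tensor')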
@@ -966,8 +1051,9 @@ def run_iteration_scannable(model, loop_state):
       )
 
     def run_one_repeat_scannable(model, loop_state):
+      weights = model._remove_logically_partition(model.layers.variables)  # pylint: disable=protected-access
       loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
+          weights, loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
       )
 
     if model.config.scan_pipeline_iterations:
@@ -997,65 +1083,6 @@ def run_one_repeat_scannable(model, loop_state):
         policy=self.get_pipeline_remat_policy(),
     )
 
-    def run_real_repeats(model, loop_state):
-      if self.config.scan_pipeline_repeats:
-        run_repeats_scanned = nn.scan(
-            run_one_repeat_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            variable_broadcast=variable_broadcast,
-            variable_carry=variable_carry,
-            split_rngs={"random": True},
-            length=model.config.num_pipeline_repeats,
-        )
-        loop_state, _ = run_repeats_scanned(model, loop_state)
-      else:
-        for _ in range(model.config.num_pipeline_repeats):  # remat and scan outer loop
-          loop_state, _ = run_one_repeat_scannable(model, loop_state)
-      return loop_state
-
-    run_real_repeats = nn.remat(
-        run_real_repeats,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
-    )
-
-    def run_bubble_iterations_scannable(model, loop_state):
-      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
-      )
-
-      if model.config.scan_pipeline_iterations:
-        run_one_repeat_scanned = nn.scan(
-            run_iteration_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            variable_broadcast=variable_broadcast,
-            variable_carry=variable_carry,
-            # Dropout/aqt keys will be split for each iteration.
-            split_rngs={"random": True},
-            length=bubble_iterations,
-        )
-        loop_state, _ = run_one_repeat_scanned(model, loop_state)
-      else:
-        for _ in range(model.config.num_pipeline_microbatches):
-          loop_state, _ = run_iteration_scannable(model, loop_state)
-      return loop_state, None
-
-    run_bubble_iterations_scannable = nn.remat(
-        run_bubble_iterations_scannable,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
-    )
-
     def run_all_iterations(model, loop_state):
       if self.config.scan_pipeline_repeats:
         run_repeats_scanned = nn.scan(
@@ -1073,7 +1100,7 @@ def run_all_iterations(model, loop_state):
         )
 
         run_bubbles_scanned = nn.scan(
-            run_bubble_iterations_scannable,
+            run_iteration_scannable,
            variable_axes={
                "summaries": 0,
                "aux_loss": 0,
@@ -1083,9 +1110,10 @@
            variable_broadcast=variable_broadcast,
            variable_carry=variable_carry,
            split_rngs={"random": True},
-            length=model.config.num_pipeline_repeats,
+            length=bubble_iterations,
        )
        loop_state, _ = run_repeats_scanned(model, loop_state)
+        loop_state["bsw"] = (loop_state["bsw"][1], jax.tree.map(jnp.zeros_like, loop_state["bsw"][1]))
        loop_state, _ = run_bubbles_scanned(model, loop_state)
      else:
        for _ in range(model.config.num_pipeline_repeats):  # remat and scan outer loop
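
The inserted line rotates the double buffer between the real repeats and the bubble iterations: the freshly gathered copy becomes the current one, and the staging slot is zeroed. In isolation, with scalar stand-ins for the weight pytrees:

import jax
import jax.numpy as jnp

bsw = (jnp.zeros(3), jnp.ones(3))                     # (current, newly gathered)
bsw = (bsw[1], jax.tree.map(jnp.zeros_like, bsw[1]))  # promote new, clear staging
print(bsw[0], bsw[1])                                 # [1. 1. 1.] [0. 0. 0.]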

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
@@ -881,7 +881,7 @@ prometheus_port: 0
 enable_jax_profiler: False
 jax_profiler_port: 9999
 
-log_config: True # Prints the config (after defaults have been set by pyconfig logic)
+log_config: False # Prints the config (after defaults have been set by pyconfig logic)
 debug_sharding: False # Prints model weights sharding info
 
 # Checkpoint Structured logging
