Commit b6dc0fc
PR #3501: DO NOT MERGE Clean check vma
Imported from GitHub PR #3501

Copybara import of the project:

--
ec1b98b by Shuwen Fang <shuwenf@google.com>:

moe formatting changes

--
7618024 by Shuwen Fang <shuwenf@google.com>:

check vma changes

--
41b53c9 by Shuwen Fang <shuwenf@google.com>:

check vma changes

Merging this change closes #3501

FUTURE_COPYBARA_INTEGRATE_REVIEW=#3501 from AI-Hypercomputer:clean-check-vma 41b53c9
PiperOrigin-RevId: 889398062
1 parent 87b1861 commit b6dc0fc

5 files changed

Lines changed: 54 additions & 46 deletions


src/maxtext/kernels/megablox/backend.py

Lines changed: 8 additions & 2 deletions
@@ -298,6 +298,7 @@ def _calculate_bytes(x: jax.Array | qpl.QArray) -> int:
         "tiling",
         "transpose_rhs",
         "interpret",
+        "vma_axes",
     ],
 )
 def gmm(
@@ -310,6 +311,7 @@ def gmm(
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
+    vma_axes: tuple = tuple(),
 ) -> jnp.ndarray:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
@@ -522,7 +524,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
   }
   call_gmm = qpl.pallas_call(
       kernel,
-      out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
+      out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type, vma=set(vma_axes)),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=2,
           in_specs=[
@@ -558,13 +560,15 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
   return out


+# calculates drhs - expert weight gradient
 @functools.partial(
     jax.jit,
     static_argnames=[
         "preferred_element_type",
         "tiling",
         "num_actual_groups",
         "interpret",
+        "vma_axes",
     ],
 )
 def tgmm(
@@ -577,6 +581,7 @@ def tgmm(
     num_actual_groups: int | None = None,
     existing_out: jnp.ndarray | None = None,
     interpret: bool = False,
+    vma_axes: tuple = tuple(),
 ) -> jnp.ndarray:
   """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].
@@ -773,9 +778,10 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
       "prefer_element_type": jnp.dtype(preferred_element_type).name,
       "num_actual_groups": num_actual_groups,
   }
+  # computes drhs via the pallas kernel call below
   call_gmm = qpl.pallas_call(
       kernel,
-      out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type),
+      out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type, vma=set(vma_axes)),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=2,
           in_specs=[
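The new vma argument threads varying-manual-axes (VMA) metadata into the kernel's output aval so shard_map can type-check the pallas_call result. A minimal sketch of the construction, assuming a recent JAX whose ShapeDtypeStruct accepts a vma keyword (the axis names below are illustrative, not from this PR):

import jax
import jax.numpy as jnp

# Hypothetical mesh axis names; in this PR they are supplied by the caller.
vma_axes = ("fsdp", "expert")

# The output spec records which mesh axes the result varies over, so the
# kernel output matches what check_vma expects inside shard_map.
out_shape = jax.ShapeDtypeStruct((128, 256), jnp.bfloat16, vma=set(vma_axes))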

src/maxtext/kernels/megablox/ops.py

Lines changed: 18 additions & 6 deletions
@@ -44,8 +44,8 @@ def gmm(
     weight_gather_axes: List[Tuple[str, int]] | None = None,
     input_buffer_count: tuple[int, int, int] = (2, 2, 2),
     combine_scopes: bool = False,
-    # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
-    qwix_rule: qwix.QtRule | None = None,
+    lhs_vma_axes: tuple = tuple(),
+    rhs_vma_axes: tuple = tuple(),
 ):
   """Grouped matrix multiplication operation."""
   quantization_rule = None
@@ -64,9 +64,13 @@ def gmm(
       act_calibration_method="absmax",
   )

-  gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0]  # pylint: disable=C3001
+  _gmm_fwd_vma = functools.partial(_gmm_fwd, lhs_vma_axes=tuple())
+  _gmm_bwd_vma = functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype, lhs_vma_axes=tuple(), rhs_vma_axes=("expert",))
+  gmm_fwd_bwd = lambda *args: _gmm_fwd_vma(*args)[0]  # pylint: disable=C3001
+  # define a custom backward pass for efficiency: it computes dlhs
+  # (activation gradients for earlier layers) and drhs (weight gradients)
   gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
-  gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
+  gmm_fwd_bwd.defvjp(_gmm_fwd_vma, _gmm_bwd_vma)
   return gmm_fwd_bwd(
       lhs,
       rhs,
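This is the standard jax.custom_vjp wiring, with functools.partial pinning the VMA arguments before defvjp. A self-contained sketch of the same pattern, with a plain matmul standing in for the Pallas GMM kernel (all names below are illustrative):

import jax
import jax.numpy as jnp

def _fwd(lhs, rhs, scale):
  out = (lhs @ rhs) * scale
  return out, (lhs, rhs)  # residuals saved for the backward pass

def _bwd(scale, residuals, grad):  # non-differentiable args arrive first
  lhs, rhs = residuals
  dlhs = (grad @ rhs.T) * scale  # gradient w.r.t. activations
  drhs = (lhs.T @ grad) * scale  # gradient w.r.t. weights
  return dlhs, drhs

matmul = jax.custom_vjp(lambda lhs, rhs, scale: _fwd(lhs, rhs, scale)[0], nondiff_argnums=(2,))
matmul.defvjp(_fwd, _bwd)

x, w = jnp.ones((4, 8)), jnp.ones((8, 2))
print(jax.grad(lambda a: matmul(a, w, 0.5).sum())(x).shape)  # (4, 8)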
@@ -85,6 +89,7 @@ def gmm(
   )


+# wraps backend kernel
 def _gmm_fwd(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
@@ -100,6 +105,7 @@ def _gmm_fwd(
     quantization_rule: qwix.QtRule | None = None,
     use_tokamax_backend: bool = False,
     weight_gather_axes: List[Tuple[str, int]] | None = None,
+    lhs_vma_axes: tuple = tuple(),
 ) -> tuple[
     jnp.ndarray,
     tuple[
@@ -129,7 +135,7 @@ def _gmm_fwd(
         calibration_method=quantization_rule.weight_calibration_method,
     )
   # QAG is only supported for following conditions
-  if use_tokamax_backend:
+  if use_tokamax_backend:  # false
     if quantization_rule and quantization_rule.bwd_qtype:
       if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray):
         if weight_gather_axes:
@@ -159,10 +165,12 @@ def _gmm_fwd(
       existing_out,
       transpose_rhs=transpose_rhs,
       interpret=interpret,
+      vma_axes=tuple(),
   )
   return out, (lhs, rhs, group_sizes, group_offset)


+# custom backward function
 def _gmm_bwd(
     lhs_dtype: jax.typing.DTypeLike,
     rhs_dtype: jax.typing.DTypeLike,
@@ -182,6 +190,8 @@ def _gmm_bwd(
         jnp.ndarray | None,
     ],
     grad: jnp.ndarray,
+    lhs_vma_axes: tuple = tuple(),  # axes for SiLU output - fsdp
+    rhs_vma_axes: tuple = tuple(),  # axes for W_out - expert
 ) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]:
   """Backward function for throughput GMM VJP."""
   del preferred_element_type
@@ -223,7 +233,7 @@ def _gmm_bwd(
         channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [1],
         calibration_method=quantization_rule.bwd_calibration_method,
     )
-  if use_tokamax_backend:
+  if use_tokamax_backend:  # false
     dlhs = tokamax_backend.gmm(
         lhs=dlhs_dout,
         rhs=rhs,
@@ -263,6 +273,7 @@ def _gmm_bwd(
       group_offset,
       transpose_rhs=not transpose_rhs,
       interpret=interpret,
+      vma_axes=lhs_vma_axes,
   )
   drhs = backend.tgmm(
       lhs.swapaxes(0, 1),
@@ -273,6 +284,7 @@ def _gmm_bwd(
       group_offset,
       num_actual_groups,
       interpret=interpret,
+      vma_axes=("expert",),
   )

   # NOTE: If the rhs transposition is fused into the forward pass we need to
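The two backend calls above compute the usual matmul gradients, grouped per expert: dlhs = grad @ rhs^T (gmm with the rhs transposition flipped) and drhs = lhs^T @ grad (tgmm on the swapped lhs). A single-group toy check, with dense matmuls standing in for the kernels:

import jax.numpy as jnp

lhs = jnp.arange(6.0).reshape(3, 2)   # activations, shape (m, k)
rhs = jnp.arange(8.0).reshape(2, 4)   # one expert weight, shape (k, n)
grad = jnp.ones((3, 4))               # upstream gradient, shape (m, n)

dlhs = grad @ rhs.T   # what gmm(grad, rhs, transpose_rhs=not transpose_rhs) yields
drhs = lhs.T @ grad   # what tgmm(lhs.swapaxes(0, 1), grad, ...) yields per group
assert dlhs.shape == lhs.shape and drhs.shape == rhs.shape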

src/maxtext/layers/moe.py

Lines changed: 26 additions & 36 deletions
@@ -640,9 +640,7 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
     sorted_selected_experts = jnp.argsort(flatten_selected_experts)
     # sort inputs for number of selected experts
     replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
-    sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(
-        self.dtype
-    )
+    sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(self.dtype)
     group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
     # Return the experts for each sorted input.
     expert_indices = jnp.arange(self.num_experts)
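For context on the reflowed lines: permute replicates each token once per selected expert, sorts the copies by expert id so each expert's rows are contiguous, and counts group sizes for the grouped matmul. A toy version (2 tokens, top-2 routing over 3 experts):

import jax.numpy as jnp

selected_experts = jnp.array([[2, 0], [1, 2]]).reshape(-1)       # token i picks 2 experts
sort_order = jnp.argsort(selected_experts)                       # stable sort by expert id
replicated = jnp.repeat(jnp.array([[10.0], [20.0]]), 2, axis=0)  # one row copy per pick
sorted_inputs = replicated[sort_order]                           # rows grouped by expert
group_size = jnp.bincount(selected_experts, length=3)            # tokens per expert: [1 1 2]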
@@ -893,9 +891,13 @@ sparse_matmul(
   ):
     """Perform sparse matrix multiplication of inputs and Experts."""

-    def gmm(
-        inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
-    ):
+    vma_axes = tuple(axis for axis in self.config.mesh_axes if self.mesh.shape[axis] > 1)
+    use_vma = not self.config.use_tokamax_gmm
+
+    vma_axes = tuple(axis for axis in self.config.mesh_axes if self.mesh.shape[axis] > 1)
+    use_vma = not self.config.use_tokamax_gmm
+
+    def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes):
       # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
       if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
         tokamax_group_sizes = group_sizes
@@ -1033,14 +1035,13 @@ def gmm(
       batch_logical_axis = "activation_batch_no_exp_moe"

     if self.get_tensor_transpose_parallelism_size() > 1:
-      input_partition_pspec = self._logical_to_mesh_axes(
-          (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
-      )
+      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
       wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
     else:
-      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
+      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      # expert weights are sharded by exp
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
       wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
@@ -1103,7 +1104,7 @@ def gmm(
             P(),  # Handle None or replicate the output
             P(),  # Handle None or replicate the output
         ),
-        check_vma=False,
+        check_vma=use_vma,
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
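check_vma is shard_map's check that each output actually varies over the mesh axes its out_spec declares; it was previously disabled here and is now enabled whenever the non-tokamax GMM path runs. A minimal sketch, assuming a recent JAX that exposes jax.shard_map with a check_vma flag (older releases call it check_rep):

import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))

@functools.partial(jax.shard_map, mesh=mesh, in_specs=P("data"), out_specs=P("data"),
                   check_vma=True)  # verify the output varies over "data" as declared
def double(x):
  return x * 2.0

print(double(jnp.arange(float(jax.device_count() * 2))))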
@@ -1119,8 +1120,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):

       # Duplicate inputs to all expert shards.
       x, logits, pre_bias_logits = tuple(
-          jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
-          for z in (x, logits, pre_bias_logits)
+          jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True) for z in (x, logits, pre_bias_logits)
       )

       # "Route" tokens within each shard.
@@ -1262,6 +1262,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):

     wi_combine_scopes = self.config.wi_combine_scopes
     wo_combine_scopes = self.config.wo_combine_scopes
+    # x * W_gate
     layer_w0 = gmm_fn(
         x,
         w0,
@@ -1274,8 +1275,8 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w0 = layer_w0 + w0_bias
-    layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
-
+    layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+    # x * W_up
     layer_w1 = gmm_fn(
         x,
         w1,
@@ -1288,9 +1289,10 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w1 = layer_w1 + w1_bias
-    layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
+    layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+    # multiplied result from W_gate and W_up before downward projection
     intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)
-
+    # output of FFN
     intermediate_output = gmm_fn(
         intermediate_layer,
         wo,
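Taken together, the three gmm_fn calls implement the gated expert FFN that the new comments label. A dense, single-expert sketch, assuming SiLU gating inside apply_ffn_activation (the actual activation is configurable):

import jax
import jax.numpy as jnp

x = jnp.ones((4, 8))                           # tokens already routed to this expert
w0 = jnp.ones((8, 16))                         # W_gate
w1 = jnp.ones((8, 16))                         # W_up
wo = jnp.ones((16, 8))                         # W_out (down projection)

intermediate = jax.nn.silu(x @ w0) * (x @ w1)  # apply_ffn_activation, assuming SiLU
output = intermediate @ wo                     # output of the FFN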
@@ -1435,10 +1437,8 @@ def reshape_and_update_weights(self, weights, indices):
     # output of updated weights: (batch_size, seq_len, num_experts)
     update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype)
     index_update = (
-        self._maybe_shard_with_logical(
-            jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None)
-        ),
-        self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)),
+        self._maybe_shard_with_logical(jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None)),
+        self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)),
         indices,
     )
     weight_sharding = (
@@ -1664,9 +1664,7 @@ def aqt_einsum(*args, **kwargs):  # pylint: disable=unused-argument
       einsum_op = jnp.einsum
     return einsum_op

-  def maybe_all_gather_kernel_weight_in_expert_parallelism(
-      self, kernel: jax.Array, kernel_axes: Tuple[Optional[str], ...]
-  ):
+  def maybe_all_gather_kernel_weight_in_expert_parallelism(self, kernel: jax.Array, kernel_axes: Tuple[Optional[str], ...]):
     """All-gather kernel weight in expert parallelism if needed."""
     if self.get_expert_parallelism_size() > 1:
       # This will trigger all-gather using weight_dtype
# This will trigger all-gather using weight_dtype
@@ -1691,14 +1689,10 @@ def dense_matmul(
   ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     """Dense matrix multiplication."""
     # gate_logits: batch, length, expert
-    gate_logits = self._maybe_shard_with_logical(
-        gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
-    )
+    gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None))
     if self.config.model_name.startswith("deepseek3"):
       # pre_bias_logits is None for non-DeepSeek v3 models
-      pre_bias_logits = self._maybe_shard_with_logical(
-          pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
-      )
+      pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, ("activation_batch", "activation_norm_length", None))
     top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
     is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
     if is_llama4_decoder_layer:
@@ -1711,9 +1705,7 @@ def dense_matmul(
     # Calculate load balance loss
     if self.config.model_call_mode != "inference":
       softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
-      lb_loss = (
-          self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None
-      )
+      lb_loss = self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None
     else:
       lb_loss = None
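For reference, the routing math in this block in isolation: the softmax runs in float32 for numerical stability before casting back, and the load-balance loss is only computed when its weight is positive. A toy top-2 selection over three experts:

import jax
import jax.numpy as jnp

gate_logits = jnp.array([[1.0, 3.0, 2.0]], dtype=jnp.bfloat16)  # (tokens, experts)
softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(jnp.bfloat16)
top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, 2)  # per-token top-2 experts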

@@ -1990,9 +1982,7 @@ def retrieve_quantized_weight(
     # This is called only during tracing. This is to invoke creation of
     # quantized tensor inside AqtEinsum. After jit, this will become no-op and
     # will not affect performance.
-    _ = self.dense_matmul(
-        inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
-    )
+    _ = self.dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias)

     w0_kernel = self.variables["aqt"]["AqtEinsum_0"]["AqtDotGeneral_0"]["qrhs"]["frozen"]
     w1_kernel = self.variables["aqt"]["AqtEinsum_1"]["AqtDotGeneral_0"]["qrhs"]["frozen"]

src/maxtext/models/deepseek.py

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ def mlp_op(self, x, deterministic, *args, **kwargs):
   def with_logical_constraint(self, x):
     return maybe_shard_with_logical(
         x,
-        logical_axes=self.logical_axis_names,
+        logical_axes=tuple(self.logical_axis_names),
         mesh=self.mesh,
         shard_mode=self.config.shard_mode,
         debug_sharding=self.config.debug_sharding,
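One plausible reason for the tuple(...) cast, sketched below: a list of logical axis names is unhashable, so any downstream path that caches on the axes or treats them as a static argument would fail, while a tuple is safe either way (the PR itself does not state the trigger):

logical_axis_names = ["activation_batch", "activation_norm_length", "activation_embed"]

hash(tuple(logical_axis_names))   # fine: tuples are hashable
try:
  hash(logical_axis_names)        # lists are not
except TypeError as err:
  print(err)                      # unhashable type: 'list'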

src/maxtext/utils/gradient_accumulation.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def _maybe_shard_with_name(inputs, sharding_names):
     return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding)

   # For more efficient DP/ZeRO-1 + GA
-  if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
+  if config.shard_mode == ShardMode.EXPLICIT and model.mesh.shape.get("data", 1) > 1:
     ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
     grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
   else:
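Mesh.shape maps axis names to sizes, so the new condition reads the mesh that was actually built rather than the ici_data_parallelism config field, which may not reflect the final data-parallel degree. A sketch with an illustrative two-axis mesh:

import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# mesh.shape behaves like an ordered dict, e.g. {"data": N, "model": 1}
if mesh.shape.get("data", 1) > 1:
  print("use reduced/unreduced gradient shardings for DP/ZeRO-1 + GA")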
