
Commit 3fc8a2b

Merge pull request #3627 from AI-Hypercomputer:nicogrande/fused-moe-gmm
PiperOrigin-RevId: 903492693
2 parents 24a3c79 + d0a0744

7 files changed: 485 additions & 21 deletions


src/maxtext/configs/inference/vllm.yml

Lines changed: 12 additions & 12 deletions
@@ -30,17 +30,19 @@ weight_dtype: bfloat16
 mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
 logical_axis_rules: [
   ['activation_batch', ['data']],
-  ['activation_batch_moe', []],
+  ['activation_batch_moe', ['data']],
   ['activation_embed_and_logits_batch', ['data', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
   ['activation_heads', ['model', 'expert']],
   ['activation_kv_heads', ['model', 'expert']],
   ['activation_attn_length', []],
-  ['activation_length', ['data']],
-  ['activation_length_moe', ['data', 'expert']],
-  ['activation_length_moe', 'data'],
+  ['activation_length', []],
+  ['activation_length_moe', []],
   ['activation_q_length', ['expert', 'attn_dp_expert']],
   ['activation_attn_embed', 'model'],
+  # Expert is missing explicitly from activation_embed despite using TP.
+  # We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
+  # due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
   ['activation_embed', ['model', 'attn_dp']],
   ['activation_embed_moe', ['model', 'attn_dp']],
   ['activation_mlp', ['model', 'attn_dp']],
@@ -53,23 +55,21 @@ logical_axis_rules: [
   ['activation_norm_length', []],
   ['activation_norm_length_moe', []],
   ['activation_exp', ['expert', 'attn_dp_expert']],
-  ['decode_batch', ['expert', 'attn_dp_expert']],
-  ['decode_batch_moe', []],
+  ['decode_batch', ['data']],
+  ['decode_batch_moe', ['data']],
   ['decode_length', []],
   ['mlp', ['model', 'attn_dp']],
   ['mlp_moe', ['model', 'attn_dp']],
   ['mlp_no_fsdp', ['model', 'attn_dp']],
   ['vocab', ['model', 'attn_dp']],
-  ['heads', ['model']],
+  # Expert is intended to act like TP for attention.
+  # We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
+  ['heads', ['model', 'expert']],
   ['q_heads', ['model', 'expert']],
   ['kv_heads', ['model', 'expert']],
   ['kv_head_dim', []],
   ['kv', []],
-  ['embed', ['expert', 'attn_dp_expert']],
-  ['embed', ['attn_dp_expert']],
-  ['embed_vocab', ['expert', 'attn_dp_expert']],
-  ['embed_vocab', ['attn_dp_expert']],
-  ['embed_moe', []],
+  ['embed', []],
   ['embed_moe', []],
   ['embed_tensor_transpose', ['attn_dp', 'model']],
   ['q_lora', ['expert', 'attn_dp_expert']],
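
Note on the rule changes above: each logical_axis_rules entry maps a logical activation or weight axis to the mesh axes it is sharded over, and an empty list means the axis is replicated. A minimal sketch of how such rules resolve to a PartitionSpec via Flax's logical_to_mesh_axes, using an assumed subset of the rules (illustrative only, not MaxText code):

# Illustrative sketch (not MaxText code): how logical axis rules like the ones
# above resolve into a PartitionSpec. The rule subset here is an assumption
# chosen just for the example.
import flax.linen as nn

rules = (
    ("activation_batch_moe", "data"),   # shard this logical axis over the 'data' mesh axis
    ("activation_length_moe", None),    # None (or [] in the yml) -> replicated
    ("embed", None),
)

spec = nn.logical_to_mesh_axes(("activation_batch_moe", "activation_length_moe", "embed"), rules)
print(spec)  # e.g. PartitionSpec('data', None, None)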

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
@@ -698,6 +698,11 @@ class MoEGeneral(BaseModel):
       False,
       description="Whether to cast inputs to fp32 to compute MoE gate logits for numerical stability.",
   )
+  prefuse_moe_weights: bool = Field(
+      False,
+      description="Whether to pre-fuse MoE weights (w0 and w1) during initialization. "
+      "This is useful for inference performance in vllm_rpa mode.",
+  )


 class MoEKernels(BaseModel):

src/maxtext/inference/vllm_decode.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ def decode_with_vllm(config: Config) -> None:
           "weight_dtype": "bfloat16",
           "allow_split_physical_axes": True,
           "debug_sharding": config.debug_sharding,
+          "prefuse_moe_weights": config.prefuse_moe_weights,
       },
       "sharding": {
           "sharding_strategy": {

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 8 additions & 5 deletions
@@ -124,6 +124,9 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
     # Model creation
     self.model: nnx.Module | None = None

+    # Indicates that the model handles its own sharding logic
+    self._self_manages_sharding = True
+
     # Handle dummy weight loading during initialization
     if vllm_config.load_config.load_format == "dummy":
       self.load_weights(rng_key)
@@ -161,8 +164,8 @@ def __call__(
       raise ValueError("Model must be an instance of type nnx.Module.")

     # Ensure inputs are at least 2D with a batch dimension
-    input_ids = jnp.atleast_2d(input_ids)
-    input_positions = jnp.atleast_2d(attention_metadata.input_positions)
+    input_ids = jnp.expand_dims(input_ids, axis=1)
+    input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1)

     with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
       aux_hidden_states = []
@@ -233,7 +236,7 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:

     with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
       # Reshape to (num_tokens, 1, hidden_dim) for decoder output head
-      y = hidden_states[:, jnp.newaxis, :]
+      y = jnp.expand_dims(hidden_states, axis=1)

       # Compute logits using the MaxText decoder's output head
       logits = self.model.decoder.apply_output_head(self.model.token_embedder, y, True, self.model_mode)
@@ -250,8 +253,8 @@ def load_weights(self, rng_key: jax.Array) -> None:
     if self.model is not None:
       return

-    with self.mesh, nn.logical_axis_rules(""):
-      model = model_creation_utils.from_pretrained(
+    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
+      model, _ = model_creation_utils.create_nnx_model(
           self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
       )
     self.model = nnx.data(model)
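
Note on the input reshaping change above: for a flat stream of T token ids, jnp.atleast_2d prepends the new axis and yields shape (1, T), whereas jnp.expand_dims(..., axis=1) yields (T, 1), so each token becomes its own length-1 row. A minimal shape check (illustrative values only):

# Shape check for the change above: jnp.atleast_2d vs jnp.expand_dims(axis=1)
# on a flat stream of token ids (toy values).
import jax.numpy as jnp

input_ids = jnp.array([11, 42, 7, 99])           # shape (4,) — four tokens
print(jnp.atleast_2d(input_ids).shape)           # (1, 4): one row of four tokens
print(jnp.expand_dims(input_ids, axis=1).shape)  # (4, 1): four rows of one token each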

src/maxtext/layers/moe.py

Lines changed: 106 additions & 4 deletions
@@ -388,7 +388,10 @@ def __init__(
         kernel_init=self.kernel_init,
         kernel_axes=self.kernel_axes,
         use_bias=self.config.routed_bias,
-        score_func=self.config.routed_score_func,
+        # tpu-inference applies the score function in the fused_moe_gmm kernel,
+        # so we don't apply it here to avoid redundant computation.
+        # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58.
+        score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func,
         matmul_precision=self.config.matmul_precision,
         shard_mode=config.shard_mode,
         rngs=self.rngs,
@@ -407,6 +410,27 @@ def __init__(
       self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
       self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
       self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
+    elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
+      self.wi = nnx.Param(
+          self.kernel_init(
+              self.rngs.params(),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim * 2),
+              weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wi_kernel_axes,
+      )
+      self.wo = nnx.Param(
+          self.kernel_init(
+              self.rngs.params(),
+              (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
+              self.weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wo_kernel_axes,
+      )
     else:
       self.wi_0 = nnx.Param(
           self.kernel_init(
@@ -2009,6 +2033,72 @@ def dense_matmul(
     ).astype(self.dtype)
     return output, lb_loss, bias_updates

+  def fused_moe_matmul(
+      self,
+      inputs,
+      gate_logits,
+      wo_kernel,
+      w0_kernel=None,
+      w1_kernel=None,
+      fused_kernel=None,
+  ) -> tuple[jax.Array, None, None]:
+    """Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only).
+
+    fused_moe_func handles routing, GMM, and weighted combination internally.
+    It does not compute lb_loss or bias_updates (inference-only).
+    """
+    try:
+      # pylint: disable=import-outside-toplevel
+      # pytype: disable=import-error
+      from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
+    except ImportError as e:
+      raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e
+
+    # Reshape 3D [B, S, D] -> 2D [T, D] (fused_moe_func expects 2D input)
+    batch_size, seq_len, emb_dim = inputs.shape
+    hidden_states = jnp.reshape(inputs, (batch_size * seq_len, emb_dim))
+    gating_output = jnp.reshape(gate_logits, (batch_size * seq_len, self.num_experts))
+
+    # Concatenate gate and up projections: [E, D, H] + [E, D, H] -> [E, D, 2H]
+    # fused_moe_func splits this internally: gate=w1[..., :H], up=w1[..., H:]
+    if fused_kernel is None:
+      fused_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1)
+
+    # Use expert parallelism if the expert axis has size > 1
+    use_ep = self.get_expert_parallelism_size() > 1
+
+    # Map MaxText config fields to fused_moe_func args
+    activation = self.config.mlp_activations[0]  # e.g. "silu"
+    scoring_fn = self.config.routed_score_func if self.config.routed_score_func else "softmax"
+
+    # Check if the model architecture intrinsically renormalizes weights
+    renormalize = self.config.norm_topk_prob or (
+        self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4)
+    )
+
+    output_2d = fused_moe_func(
+        hidden_states=hidden_states,
+        w1=fused_kernel,
+        w2=wo_kernel,
+        w1_scale=None,
+        w2_scale=None,
+        w1_bias=None,
+        w2_bias=None,
+        gating_output=gating_output,
+        topk=self.num_experts_per_tok,
+        renormalize=renormalize,
+        mesh=self.mesh,
+        use_ep=use_ep,
+        activation=activation,
+        scoring_fn=scoring_fn,
+        sc_kernel_threshold=16777216,
+        sc_kernel_col_chunk_size=1024,
+    )
+
+    # Reshape output 2D [T, D] -> 3D [B, S, D]
+    output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
+    return output, None, None
+
   def retrieve_quantized_weight(
       self,
       inputs,
@@ -2047,10 +2137,17 @@ def __call__(
     routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
     gate_logits, pre_bias_logits = self.gate(routing_inputs)

-    w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
-    w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
     wo_kernel = jnp.asarray(self.wo[...], self.dtype)

+    fused_kernel = None
+    w0_kernel = None
+    w1_kernel = None
+    if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
+      fused_kernel = jnp.asarray(self.wi[...], self.dtype)
+    else:
+      w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
+      w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
+
     if self.per_expert_scale is not None:
       wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]

@@ -2061,7 +2158,12 @@ def __call__(
     else:
       w0_bias, w1_bias, wo_bias = None, None, None

-    if cfg.sparse_matmul:
+    # vllm_rpa codepath uses fused_moe_func from tpu_inference for optimized inference.
+    if cfg.attention == "vllm_rpa":
+      output, lb_loss, bias_updates = self.fused_moe_matmul(
+          inputs, gate_logits, wo_kernel, w0_kernel=w0_kernel, w1_kernel=w1_kernel, fused_kernel=fused_kernel
+      )
+    elif cfg.sparse_matmul:
      if quantizations.in_serve_mode(self.quant):
        w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
            inputs,
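
Note on the prefused layout used by fused_moe_matmul above: wi stores the gate (w0) and up (w1) projections concatenated along the last axis into [E, D, 2H], and the kernel is expected to slice gate = wi[..., :H] and up = wi[..., H:]. A toy check in plain JAX (not the tpu-inference kernel; shapes and values are arbitrary) that the fused layout reproduces the separate-weight computation:

# Toy check (plain JAX, not the Pallas kernel): concatenating w0 (gate) and w1 (up)
# along the last axis and slicing the halves back out gives the same per-expert MLP.
# E=2 experts, D=4 model dim, H=3 hidden dim, 5 tokens per expert — arbitrary sizes.
import jax
import jax.numpy as jnp

k0, k1, kx = jax.random.split(jax.random.PRNGKey(0), 3)
E, D, H = 2, 4, 3
w0 = jax.random.normal(k0, (E, D, H))   # gate projection per expert
w1 = jax.random.normal(k1, (E, D, H))   # up projection per expert
x = jax.random.normal(kx, (E, 5, D))    # tokens already routed to each expert

wi = jnp.concatenate([w0, w1], axis=-1)  # [E, D, 2H], the prefused layout

# Separate-weight reference: silu(x @ w0) * (x @ w1)
ref = jax.nn.silu(jnp.einsum("etd,edh->eth", x, w0)) * jnp.einsum("etd,edh->eth", x, w1)
# Fused layout: slice the halves back out of wi before the same computation
fused = jax.nn.silu(jnp.einsum("etd,edh->eth", x, wi[..., :H])) * jnp.einsum("etd,edh->eth", x, wi[..., H:])

print(jnp.allclose(ref, fused))  # True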

src/maxtext/utils/model_creation_utils.py

Lines changed: 69 additions & 0 deletions
@@ -28,6 +28,7 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import numpy as np
 from jax.sharding import Mesh
 from maxtext.configs import pyconfig
 from maxtext.common.common_types import MODEL_MODE_TRAIN
@@ -507,6 +508,39 @@ def create_sharded_state():
   # Get the structure of checkpoint in `config.load_parameters_path`
   metadata = ckptr.metadata(config.load_parameters_path)

+  def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx):
+    if not hasattr(target, "items") or not hasattr(meta_tree, "items"):
+      return target
+    new_target = {}
+    for k, v in target.items():
+      if k == "wi" and "wi" not in meta_tree and "wi_0" in meta_tree and "wi_1" in meta_tree:
+        if not is_nnx:
+          arr = v
+          half_dim = arr.shape[-1] // 2
+          new_target["wi_0"] = jax.ShapeDtypeStruct(
+              shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding
+          )
+          new_target["wi_1"] = jax.ShapeDtypeStruct(
+              shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding
+          )
+        else:
+          arr = v["value"]
+          half_dim = arr.shape[-1] // 2
+          new_target["wi_0"] = {
+              "value": jax.ShapeDtypeStruct(
+                  shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding
+              )
+          }
+          new_target["wi_1"] = {
+              "value": jax.ShapeDtypeStruct(
+                  shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding
+              )
+          }
+      else:
+        new_target[k] = _adjust_target_for_moe_fusion(v, meta_tree.get(k, {}), is_nnx)
+
+    return new_target
+
   is_nnx_checkpoint = True
   if (
       "params" in metadata.item_metadata.tree.keys()
@@ -520,6 +554,10 @@ def create_sharded_state():
         is_leaf=lambda n: hasattr(n, "value"),
     )

+    target_for_restore = _adjust_target_for_moe_fusion(
+        target_for_restore, metadata.item_metadata.tree["params"]["params"], False
+    )
+
     item_to_restore = {"params": {"params": target_for_restore}}
     base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
     restore_args = {
@@ -538,6 +576,7 @@ def create_sharded_state():
         sharded_state,
         is_leaf=lambda n: isinstance(n, nnx.Variable),
     )
+    target_for_restore = _adjust_target_for_moe_fusion(target_for_restore, metadata.item_metadata.tree, True)
     item_to_restore = target_for_restore
     base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
     restore_args = _fix_restore_args_for_shape_mismatch(
@@ -577,6 +616,36 @@ def create_sharded_state():
       sharded_state,
       is_leaf=lambda n: isinstance(n, nnx.Variable),
   )
+
+  def to_dict(tree):
+    if hasattr(tree, "items"):
+      return {k: to_dict(v) for k, v in tree.items()}
+    return tree
+
+  model_arrays = to_dict(model_arrays)
+  checkpoint = to_dict(checkpoint)
+
+  def _fuse_moe_weights(ckpt_tree, model_arrays_tree):
+    if not hasattr(ckpt_tree, "items") or not hasattr(model_arrays_tree, "items"):
+      return ckpt_tree
+    new_ckpt = {}
+    for k, v in ckpt_tree.items():
+      if k in ("wi_0", "wi_1") and "wi" in model_arrays_tree:
+        continue
+      new_ckpt[k] = _fuse_moe_weights(v, model_arrays_tree.get(k, {}))
+
+    if "wi" in model_arrays_tree and "wi_0" in ckpt_tree and "wi_1" in ckpt_tree:
+      wi_0 = ckpt_tree["wi_0"]
+      wi_1 = ckpt_tree["wi_1"]
+      new_ckpt["wi"] = np.concatenate([wi_0, wi_1], axis=-1)
+
+    return new_ckpt
+
+  checkpoint = _fuse_moe_weights(checkpoint, model_arrays)
+  # Release the raw restored buffers now that wi_0/wi_1 have been fused (if needed).
+  # This prevents the replicated intermediate copies from persisting until function return.
+  del restored
+
   checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays)
   nnx.update(model, checkpoint)
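
Note on the checkpoint fusion above: an unfused checkpoint carries wi_0 and wi_1, while a model built with prefuse_moe_weights expects a single wi of doubled last dimension, so the restore target is first split back into two half-width ShapeDtypeStructs and the restored halves are then concatenated. A standalone toy sketch of that fusion step on plain dict trees (no Orbax involved; names and shapes are made up):

# Toy sketch of the fusion step on plain dict trees: a checkpoint carrying
# separate wi_0/wi_1 is rewritten to match a model that expects a single
# prefused "wi" with doubled last dimension.
import numpy as np

checkpoint = {
    "moe_block": {
        "wi_0": np.ones((2, 4, 3)),        # [experts, model_dim, hidden_dim]
        "wi_1": np.full((2, 4, 3), 2.0),
        "wo": np.zeros((2, 3, 4)),
    }
}
model_arrays = {"moe_block": {"wi": np.zeros((2, 4, 6)), "wo": np.zeros((2, 3, 4))}}

def fuse(ckpt, model):
    if not isinstance(ckpt, dict):
        return ckpt
    # Drop the unfused halves wherever the model wants a fused "wi"; recurse elsewhere.
    out = {k: fuse(v, model.get(k, {})) for k, v in ckpt.items()
           if not (k in ("wi_0", "wi_1") and "wi" in model)}
    if "wi" in model and "wi_0" in ckpt and "wi_1" in ckpt:
        out["wi"] = np.concatenate([ckpt["wi_0"], ckpt["wi_1"]], axis=-1)
    return out

fused = fuse(checkpoint, model_arrays)
print(fused["moe_block"]["wi"].shape)  # (2, 4, 6)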