@@ -388,7 +388,10 @@ def __init__(
         kernel_init=self.kernel_init,
         kernel_axes=self.kernel_axes,
         use_bias=self.config.routed_bias,
-        score_func=self.config.routed_score_func,
+        # tpu-inference applies the score function in the fused_moe_gmm kernel,
+        # so we don't apply it here to avoid redundant computation.
+        # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58.
+        score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func,
         matmul_precision=self.config.matmul_precision,
         shard_mode=config.shard_mode,
         rngs=self.rngs,
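
Why the gate passes an empty `score_func` on the vllm_rpa path: the fused TPU kernel applies the scoring function itself, so applying it in the gate as well would score the logits twice. A standalone sketch (illustrative only, not MaxText code) of why double-applying softmax changes the routing weights:

```python
# Illustrative sketch, not part of this diff: double-applying the score
# function distorts the routing weights, hence score_func="" when the
# fused kernel will apply it anyway.
import jax.numpy as jnp
from jax.nn import softmax

logits = jnp.array([[2.0, 0.5, -1.0, 0.1]])     # [tokens, num_experts]
scored_once = softmax(logits, axis=-1)          # what fused_moe_gmm computes from raw logits
scored_twice = softmax(scored_once, axis=-1)    # gate scoring + kernel scoring
print(jnp.allclose(scored_once, scored_twice))  # False
```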
@@ -407,6 +410,27 @@ def __init__(
       self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
       self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
       self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
+    elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
+      self.wi = nnx.Param(
+          self.kernel_init(
+              self.rngs.params(),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim * 2),
+              weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wi_kernel_axes,
+      )
+      self.wo = nnx.Param(
+          self.kernel_init(
+              self.rngs.params(),
+              (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
+              self.weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wo_kernel_axes,
+      )
     else:
       self.wi_0 = nnx.Param(
           self.kernel_init(
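
The prefused `wi` parameter stores the gate and up projections side by side in a single [num_experts, emb_dim, 2 * intermediate_dim] array, matching the split that fused_moe_func performs on its `w1` argument (gate = w1[..., :H], up = w1[..., H:]). A hypothetical helper (names assumed, not part of this change) for fusing unfused checkpoint weights into that layout:

```python
# Hypothetical conversion helper, not part of this change: fuse separate
# gate (wi_0) and up (wi_1) projection weights into the [E, D, 2H] layout
# that the prefused wi parameter and fused_moe_func expect.
import jax.numpy as jnp

def fuse_gate_up(wi_0: jnp.ndarray, wi_1: jnp.ndarray) -> jnp.ndarray:
  """wi_0, wi_1: [num_experts, emb_dim, intermediate_dim] -> [num_experts, emb_dim, 2 * intermediate_dim]."""
  assert wi_0.shape == wi_1.shape
  # Gate first, then up, so the kernel's w1[..., :H] / w1[..., H:] split recovers them.
  return jnp.concatenate([wi_0, wi_1], axis=-1)
```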
@@ -2009,6 +2033,72 @@ def dense_matmul(
     ).astype(self.dtype)
     return output, lb_loss, bias_updates
 
+  def fused_moe_matmul(
+      self,
+      inputs,
+      gate_logits,
+      wo_kernel,
+      w0_kernel=None,
+      w1_kernel=None,
+      fused_kernel=None,
+  ) -> tuple[jax.Array, None, None]:
+    """Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only).
+
+    fused_moe_func handles routing, GMM, and weighted combination internally.
+    It does not compute lb_loss or bias_updates (inference-only).
+    """
+    try:
+      # pylint: disable=import-outside-toplevel
+      # pytype: disable=import-error
+      from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
+    except ImportError as e:
+      raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e
+
+    # Reshape 3D [B, S, D] -> 2D [T, D] (fused_moe_func expects 2D input).
+    batch_size, seq_len, emb_dim = inputs.shape
+    hidden_states = jnp.reshape(inputs, (batch_size * seq_len, emb_dim))
+    gating_output = jnp.reshape(gate_logits, (batch_size * seq_len, self.num_experts))
+
+    # Concatenate gate and up projections: [E, D, H] + [E, D, H] -> [E, D, 2H].
+    # fused_moe_func splits this internally: gate=w1[..., :H], up=w1[..., H:].
+    if fused_kernel is None:
+      fused_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1)
+
+    # Use expert parallelism if the expert axis has size > 1.
+    use_ep = self.get_expert_parallelism_size() > 1
+
+    # Map MaxText config fields to fused_moe_func args.
+    activation = self.config.mlp_activations[0]  # e.g. "silu"
+    scoring_fn = self.config.routed_score_func if self.config.routed_score_func else "softmax"
+
+    # Check if the model architecture intrinsically renormalizes weights.
+    renormalize = self.config.norm_topk_prob or (
+        self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4)
+    )
+
+    output_2d = fused_moe_func(
+        hidden_states=hidden_states,
+        w1=fused_kernel,
+        w2=wo_kernel,
+        w1_scale=None,
+        w2_scale=None,
+        w1_bias=None,
+        w2_bias=None,
+        gating_output=gating_output,
+        topk=self.num_experts_per_tok,
+        renormalize=renormalize,
+        mesh=self.mesh,
+        use_ep=use_ep,
+        activation=activation,
+        scoring_fn=scoring_fn,
+        sc_kernel_threshold=16777216,
+        sc_kernel_col_chunk_size=1024,
+    )
+
+    # Reshape output 2D [T, D] -> 3D [B, S, D].
+    output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
+    return output, None, None
+
   def retrieve_quantized_weight(
       self,
@@ -2047,10 +2137,17 @@ def __call__(
     routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
     gate_logits, pre_bias_logits = self.gate(routing_inputs)
 
-    w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
-    w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
     wo_kernel = jnp.asarray(self.wo[...], self.dtype)
 
+    fused_kernel = None
+    w0_kernel = None
+    w1_kernel = None
+    if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
+      fused_kernel = jnp.asarray(self.wi[...], self.dtype)
+    else:
+      w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
+      w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
+
     if self.per_expert_scale is not None:
       wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
 
@@ -2061,7 +2158,12 @@ def __call__(
     else:
       w0_bias, w1_bias, wo_bias = None, None, None
 
-    if cfg.sparse_matmul:
+    # vllm_rpa codepath uses fused_moe_func from tpu_inference for optimized inference.
+    if cfg.attention == "vllm_rpa":
+      output, lb_loss, bias_updates = self.fused_moe_matmul(
+          inputs, gate_logits, wo_kernel, w0_kernel=w0_kernel, w1_kernel=w1_kernel, fused_kernel=fused_kernel
+      )
+    elif cfg.sparse_matmul:
       if quantizations.in_serve_mode(self.quant):
         w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
             inputs,
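
Note on dispatch order in __call__: the vllm_rpa branch is checked before sparse_matmul, so serving configs pick up the fused kernel without having to unset sparse_matmul. A minimal sketch of the resulting precedence (config names as used in this diff, values illustrative):

```python
# Precedence sketch mirroring the branch order added above; purely illustrative.
def pick_moe_path(attention: str, sparse_matmul: bool) -> str:
  if attention == "vllm_rpa":
    return "fused_moe_matmul"  # tpu-inference fused kernel
  if sparse_matmul:
    return "sparse_matmul"     # existing GMM path
  return "dense_matmul"        # existing dense path

print(pick_moe_path("vllm_rpa", True))  # -> fused_moe_matmul
print(pick_moe_path("flash", True))     # -> sparse_matmul
```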