@@ -384,7 +384,10 @@ def __init__(
         kernel_init=self.kernel_init,
         kernel_axes=self.kernel_axes,
         use_bias=self.config.routed_bias,
-        score_func=self.config.routed_score_func,
+        # tpu-inference applies the score function in the fused_moe_gmm kernel,
+        # so we don't apply it here to avoid redundant computation.
+        # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58.
+        score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func,
         matmul_precision=self.config.matmul_precision,
         shard_mode=config.shard_mode,
         rngs=self.rngs,
@@ -403,6 +406,27 @@ def __init__(
       self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
       self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
       self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
+    elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
+      self.wi = nnx.Param(
+          self.kernel_init(
+              self.rngs.params(),
+              (num_experts, self.config.emb_dim, intermediate_dim * 2),
+              weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wi_kernel_axes,
+      )
+      self.wo = nnx.Param(
+          self.kernel_init(
+              self.rngs.params(),
+              (self.num_experts, self.intermediate_dim, self.config.emb_dim),
+              self.weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wo_kernel_axes,
+      )
     else:
       self.wi_0 = nnx.Param(
           self.kernel_init(
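For readers mapping checkpoints onto the new prefused `wi` parameter: it packs the gate (`wi_0`) and up (`wi_1`) projections along the last axis, matching the `[E, D, 2H]` split documented in `fused_moe_matmul` below. A tiny sketch under that assumption (shapes are illustrative):

import jax.numpy as jnp

E, D, H = 8, 16, 32
wi_0 = jnp.ones((E, D, H))  # gate projection
wi_1 = jnp.ones((E, D, H))  # up projection
wi = jnp.concatenate([wi_0, wi_1], axis=-1)  # [E, D, 2H]; wi[..., :H] is gate, wi[..., H:] is up
assert wi.shape == (E, D, 2 * H)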
@@ -1970,6 +1994,72 @@ def dense_matmul(
     ).astype(self.dtype)
     return output, lb_loss, bias_updates
 
+  def fused_moe_matmul(
+      self,
+      inputs,
+      gate_logits,
+      wo_kernel,
+      w0_kernel=None,
+      w1_kernel=None,
+      fused_kernel=None,
+  ) -> tuple[jax.Array, None, None]:
+    """Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only).
+
+    fused_moe_func handles routing, GMM, and weighted combination internally.
+    It does not compute lb_loss or bias_updates (inference-only).
+    """
+    try:
+      # pylint: disable=import-outside-toplevel
+      # pytype: disable=import-error
+      from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
+    except ImportError as e:
+      raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e
+
+    # Reshape 3D [B, S, D] -> 2D [T, D] (fused_moe_func expects 2D input).
+    batch_size, seq_len, emb_dim = inputs.shape
+    hidden_states = jnp.reshape(inputs, (batch_size * seq_len, emb_dim))
+    gating_output = jnp.reshape(gate_logits, (batch_size * seq_len, self.num_experts))
+
+    # Concatenate gate and up projections: [E, D, H] + [E, D, H] -> [E, D, 2H].
+    # fused_moe_func splits this internally: gate=w1[..., :H], up=w1[..., H:].
+    if fused_kernel is None:
+      fused_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1)
+
+    # Use expert parallelism if the expert axis has size > 1.
+    use_ep = self.get_expert_parallelism_size() > 1
+
+    # Map MaxText config fields to fused_moe_func args.
+    activation = self.config.mlp_activations[0]  # e.g. "silu"
+    scoring_fn = self.config.routed_score_func if self.config.routed_score_func else "softmax"
+
+    # Check if the model architecture intrinsically renormalizes weights.
+    renormalize = self.config.norm_topk_prob or (
+        self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4)
+    )
+
+    output_2d = fused_moe_func(
+        hidden_states=hidden_states,
+        w1=fused_kernel,
+        w2=wo_kernel,
+        w1_scale=None,
+        w2_scale=None,
+        w1_bias=None,
+        w2_bias=None,
+        gating_output=gating_output,
+        topk=self.num_experts_per_tok,
+        renormalize=renormalize,
+        mesh=self.mesh,
+        use_ep=use_ep,
+        activation=activation,
+        scoring_fn=scoring_fn,
+        sc_kernel_threshold=16777216,
+        sc_kernel_col_chunk_size=1024,
+    )
+
+    # Reshape output 2D [T, D] -> 3D [B, S, D].
+    output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
+    return output, None, None
+
   def retrieve_quantized_weight(
       self,
       inputs,
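Since `fused_moe_func` is an external black box in this diff, a rough unfused reference of the computation it is documented to perform may help review: top-k routing over scored logits, a per-token expert MLP on the fused `[E, D, 2H]` kernel, and a weighted combine. The sketch below assumes softmax scoring and SiLU activation per the comments above; it illustrates the semantics and is not the tpu-inference kernel:

import jax
import jax.numpy as jnp

def moe_reference(x, w1, w2, gating_output, topk, renormalize):
  # x: [T, D], w1: [E, D, 2H], w2: [E, H, D], gating_output: [T, E]
  scores = jax.nn.softmax(gating_output, axis=-1)
  topk_w, topk_idx = jax.lax.top_k(scores, topk)  # both [T, k]
  if renormalize:
    topk_w = topk_w / jnp.sum(topk_w, axis=-1, keepdims=True)
  h = w1.shape[-1] // 2
  out = jnp.zeros_like(x)
  for slot in range(topk):
    e = topk_idx[:, slot]  # expert id per token, [T]
    gate = jnp.einsum("td,tdh->th", x, w1[e, :, :h])  # gate half of fused kernel
    up = jnp.einsum("td,tdh->th", x, w1[e, :, h:])    # up half of fused kernel
    expert_out = jnp.einsum("th,thd->td", jax.nn.silu(gate) * up, w2[e])
    out = out + topk_w[:, slot, None] * expert_out
  return out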
@@ -2008,10 +2098,17 @@ def __call__(
     routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
     gate_logits, pre_bias_logits = self.gate(routing_inputs)
 
-    w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
-    w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
     wo_kernel = jnp.asarray(self.wo[...], self.dtype)
 
+    fused_kernel = None
+    w0_kernel = None
+    w1_kernel = None
+    if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
+      fused_kernel = jnp.asarray(self.wi[...], self.dtype)
+    else:
+      w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
+      w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
+
     if self.per_expert_scale is not None:
      wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
 
@@ -2022,7 +2119,12 @@ def __call__(
     else:
       w0_bias, w1_bias, wo_bias = None, None, None
 
-    if cfg.sparse_matmul:
+    # vllm_rpa codepath uses fused_moe_func from tpu_inference for optimized inference.
+    if cfg.attention == "vllm_rpa":
+      output, lb_loss, bias_updates = self.fused_moe_matmul(
+          inputs, gate_logits, wo_kernel, w0_kernel=w0_kernel, w1_kernel=w1_kernel, fused_kernel=fused_kernel
+      )
+    elif cfg.sparse_matmul:
       if quantizations.in_serve_mode(self.quant):
         w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
             inputs,
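The new branch makes `attention == "vllm_rpa"` take precedence over `sparse_matmul`. A hypothetical helper summarizing the dispatch order established above (not actual MaxText code):

def pick_moe_path(cfg) -> str:
  if cfg.attention == "vllm_rpa":
    return "fused_moe_matmul"  # tpu-inference fused kernel; returns no lb_loss/bias_updates
  if cfg.sparse_matmul:
    return "sparse matmul (GMM) path"
  return "dense_matmul"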