diff --git a/tests/kernels/moe/test_exllama_moe.py b/tests/kernels/moe/test_exllama_moe.py index b613c99f55ff..732186d1d3e3 100644 --- a/tests/kernels/moe/test_exllama_moe.py +++ b/tests/kernels/moe/test_exllama_moe.py @@ -18,7 +18,7 @@ from tests.kernels.utils import torch_experts from vllm._custom_ops import gptq_shuffle from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( int4_w4a16_moe_quant_config, @@ -260,3 +260,131 @@ def test_exllama_moe_contiguous( force_block_size_m=2, ) torch.testing.assert_close(exllama_out, torch_output, atol=2e-2, rtol=0) + + +def _make_triton_wna16_moe_weights( + E: int, + K: int, + N: int, + group_size: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Create fake GPTQ-packed MoE weights in Triton WNA16 format. + + Mimics CompressedTensorsWNA16MoEMethod.process_weights_after_loading() + Triton fallback path: transpose [E, K/8, N] → [E, N, K/8] → uint8. + + Returns (w_packed, scales, w_ref) where: + - w_packed: [E, N, K/2] uint8 (Triton MoE format) + - scales: [E, N, num_groups] fp16 (transposed) + - w_ref: [E, N, K] fp16 (torch_experts convention) + """ + all_packed = [] + all_scales = [] + all_ref = [] + + for _ in range(E): + w_fp = torch.randn(K, N, device=device, dtype=torch.float16) / 10.0 + q_packed, scales, w_ref = _symmetric_quantize_4bit(w_fp, group_size) + # q_packed is [K/8, N] int32 — NO gptq_shuffle (Triton path skips it) + all_packed.append(q_packed) + all_scales.append(scales) + all_ref.append(w_ref.t()) # [N, K] + + # Stack: [E, K/8, N] int32 + w_packed = torch.stack(all_packed) + w_scales = torch.stack(all_scales) # [E, num_groups, N] + w_ref = torch.stack(all_ref) # [E, N, K] + + # Triton path: transpose + reinterpret as uint8 + w_packed = w_packed.transpose(1, 2).contiguous().view(torch.uint8) + w_scales = w_scales.transpose(1, 2).contiguous() + + return w_packed, w_scales, w_ref + + +def _run_triton_wna16_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build weights in Triton WNA16 format, run fused_experts and reference. + + This exercises the code path triggered by VLLM_MOE_GPTQ_EXLLAMA=false + in CompressedTensorsWNA16MoEMethod. + + Returns (triton_output, reference_output). + """ + torch.cuda.manual_seed(1) + device = torch.device("cuda") + group_size = GROUP_SIZE + + assert k % group_size == 0 + assert n % PACK_FACTOR == 0 + + w1_packed, w1_scales, w1_ref = _make_triton_wna16_moe_weights( + e, k, 2 * n, group_size, device + ) + w2_packed, w2_scales, w2_ref = _make_triton_wna16_moe_weights( + e, n, k, group_size, device + ) + + hidden = torch.randn(m, k, device=device, dtype=torch.float16) / 10 + scores = torch.randn(m, e, device=device, dtype=torch.float16) + + topk_weights, topk_ids, _ = fused_topk(hidden, scores, topk, False) + + quant_config = int4_w4a16_moe_quant_config( + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=None, + w2_zp=None, + block_shape=[0, group_size], + ) + + init_workspace_manager(device) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + torch_output = torch_experts( + hidden, + w1_ref, + w2_ref, + topk_weight=topk_weights, + topk_ids=topk_ids, + global_num_experts=e, + ) + + triton_output = fused_experts( + hidden, + w1_packed, + w2_packed, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=MoEActivation.SILU, + apply_router_weight_on_input=False, + global_num_experts=e, + expert_map=None, + quant_config=quant_config, + ) + + return triton_output, torch_output + + +@pytest.mark.skipif( + not (current_platform.is_rocm() or current_platform.is_cuda()), + reason="Requires ROCm or CUDA", +) +@pytest.mark.parametrize("m", [1, 4, 16]) +@pytest.mark.parametrize("n,k", [(256, 256), (512, 256)]) +@pytest.mark.parametrize("e,topk", [(8, 2), (16, 4)]) +def test_triton_wna16_moe(m: int, n: int, k: int, e: int, topk: int): + """Test the Triton WNA16 MoE fallback path (VLLM_MOE_GPTQ_EXLLAMA=false). + + This exercises the same code path as CompressedTensorsWNA16MoEMethod.apply() + when moe_mk is None — weights are transposed to [E, N, K/2] uint8 and run + through fused_experts() with quant_config. + """ + triton_out, torch_output = _run_triton_wna16_moe(m, n, k, e, topk) + torch.testing.assert_close(triton_out, torch_output, atol=2e-2, rtol=0) diff --git a/vllm/envs.py b/vllm/envs.py index 334af9917363..bd89a77d1788 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -487,6 +487,22 @@ def _get_or_set_default() -> str: logger = logging.getLogger(__name__) + +def _is_rdna_for_moe_default() -> bool: + """Check if running on RDNA (gfx11/gfx12) for MoE kernel default. + + On RDNA, the Triton WNA16 MoE kernel outperforms the Exllama MoE kernel + due to better handling of many-expert routing and avoiding atomicAdd + K-tiling overhead. + """ + try: + from vllm.platforms.rocm import on_gfx1x + + return on_gfx1x() + except (ImportError, Exception): + return False + + environment_variables: dict[str, Callable[[], Any]] = { # ================== Installation Time Env Vars ================== # Target device of vLLM, supporting [cuda (by default), @@ -943,8 +959,13 @@ def _get_or_set_default() -> str: ), # Use exllama 4-bit kernel for MoE GPTQ instead of Triton. # Requires exllama-native weight format [E, K/8, N] int32. + # Defaults to false on RDNA (gfx11/gfx12) where Triton MoE is faster. "VLLM_MOE_GPTQ_EXLLAMA": lambda: ( - os.getenv("VLLM_MOE_GPTQ_EXLLAMA", "true").lower() in ("true", "1") + os.getenv( + "VLLM_MOE_GPTQ_EXLLAMA", + "false" if _is_rdna_for_moe_default() else "true", + ).lower() + in ("true", "1") ), # Optional: enable external Oink custom ops (e.g., Blackwell RMSNorm). # Disabled by default. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d01d05f87fde..0e614e5565b7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1691,6 +1691,8 @@ def __init__( super().__init__(moe) self.weight_quant = weight_quant self.input_quant = input_quant + self.moe_mk: FusedMoEKernel | None = None + self.moe_quant_config: FusedMoEQuantConfig | None = None # Extract properties from weight_quant self.num_bits = weight_quant.num_bits self.packed_factor = 32 // weight_quant.num_bits @@ -1911,6 +1913,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False, ) layer.use_exllama_moe = False + self.moe_quant_config = self.get_fused_moe_quant_config(layer) def _process_weights_awq_gemv(self, layer: torch.nn.Module) -> None: """AWQ GEMV MoE path: convert GPTQ [E, K/8, N] → AWQ [E, K, N/8]