Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 129 additions & 1 deletion tests/kernels/moe/test_exllama_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tests.kernels.utils import torch_experts
from vllm._custom_ops import gptq_shuffle
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
int4_w4a16_moe_quant_config,
Expand Down Expand Up @@ -260,3 +260,131 @@ def test_exllama_moe_contiguous(
force_block_size_m=2,
)
torch.testing.assert_close(exllama_out, torch_output, atol=2e-2, rtol=0)


def _make_triton_wna16_moe_weights(
E: int,
K: int,
N: int,
group_size: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Create fake GPTQ-packed MoE weights in Triton WNA16 format.

Mimics CompressedTensorsWNA16MoEMethod.process_weights_after_loading()
Triton fallback path: transpose [E, K/8, N] → [E, N, K/8] → uint8.

Returns (w_packed, scales, w_ref) where:
- w_packed: [E, N, K/2] uint8 (Triton MoE format)
- scales: [E, N, num_groups] fp16 (transposed)
- w_ref: [E, N, K] fp16 (torch_experts convention)
"""
all_packed = []
all_scales = []
all_ref = []

for _ in range(E):
w_fp = torch.randn(K, N, device=device, dtype=torch.float16) / 10.0
q_packed, scales, w_ref = _symmetric_quantize_4bit(w_fp, group_size)
# q_packed is [K/8, N] int32 — NO gptq_shuffle (Triton path skips it)
all_packed.append(q_packed)
all_scales.append(scales)
all_ref.append(w_ref.t()) # [N, K]

# Stack: [E, K/8, N] int32
w_packed = torch.stack(all_packed)
w_scales = torch.stack(all_scales) # [E, num_groups, N]
w_ref = torch.stack(all_ref) # [E, N, K]

# Triton path: transpose + reinterpret as uint8
w_packed = w_packed.transpose(1, 2).contiguous().view(torch.uint8)
w_scales = w_scales.transpose(1, 2).contiguous()

return w_packed, w_scales, w_ref


def _run_triton_wna16_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Build weights in Triton WNA16 format, run fused_experts and reference.

This exercises the code path triggered by VLLM_MOE_GPTQ_EXLLAMA=false
in CompressedTensorsWNA16MoEMethod.

Returns (triton_output, reference_output).
"""
torch.cuda.manual_seed(1)
device = torch.device("cuda")
group_size = GROUP_SIZE

assert k % group_size == 0
assert n % PACK_FACTOR == 0

w1_packed, w1_scales, w1_ref = _make_triton_wna16_moe_weights(
e, k, 2 * n, group_size, device
)
w2_packed, w2_scales, w2_ref = _make_triton_wna16_moe_weights(
e, n, k, group_size, device
)

hidden = torch.randn(m, k, device=device, dtype=torch.float16) / 10
scores = torch.randn(m, e, device=device, dtype=torch.float16)

topk_weights, topk_ids, _ = fused_topk(hidden, scores, topk, False)

quant_config = int4_w4a16_moe_quant_config(
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=None,
w2_zp=None,
block_shape=[0, group_size],
)

init_workspace_manager(device)
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
torch_output = torch_experts(
hidden,
w1_ref,
w2_ref,
topk_weight=topk_weights,
topk_ids=topk_ids,
global_num_experts=e,
)

triton_output = fused_experts(
hidden,
w1_packed,
w2_packed,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
global_num_experts=e,
expert_map=None,
quant_config=quant_config,
)

return triton_output, torch_output


@pytest.mark.skipif(
not (current_platform.is_rocm() or current_platform.is_cuda()),
reason="Requires ROCm or CUDA",
)
@pytest.mark.parametrize("m", [1, 4, 16])
@pytest.mark.parametrize("n,k", [(256, 256), (512, 256)])
@pytest.mark.parametrize("e,topk", [(8, 2), (16, 4)])
def test_triton_wna16_moe(m: int, n: int, k: int, e: int, topk: int):
"""Test the Triton WNA16 MoE fallback path (VLLM_MOE_GPTQ_EXLLAMA=false).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test CompressedTensorsWNA16MoEMethod.apply()

This exercises the same code path as CompressedTensorsWNA16MoEMethod.apply()
when moe_mk is None — weights are transposed to [E, N, K/2] uint8 and run
through fused_experts() with quant_config.
"""
triton_out, torch_output = _run_triton_wna16_moe(m, n, k, e, topk)
torch.testing.assert_close(triton_out, torch_output, atol=2e-2, rtol=0)
23 changes: 22 additions & 1 deletion vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,22 @@ def _get_or_set_default() -> str:

logger = logging.getLogger(__name__)


def _is_rdna_for_moe_default() -> bool:
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this and the venv variable and do proper is_compatible checks in compressed_tensor_moe.py

"""Check if running on RDNA (gfx11/gfx12) for MoE kernel default.

On RDNA, the Triton WNA16 MoE kernel outperforms the Exllama MoE kernel
due to better handling of many-expert routing and avoiding atomicAdd
K-tiling overhead.
"""
try:
from vllm.platforms.rocm import on_gfx1x

return on_gfx1x()
except (ImportError, Exception):
return False


environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default),
Expand Down Expand Up @@ -943,8 +959,13 @@ def _get_or_set_default() -> str:
),
# Use exllama 4-bit kernel for MoE GPTQ instead of Triton.
# Requires exllama-native weight format [E, K/8, N] int32.
# Defaults to false on RDNA (gfx11/gfx12) where Triton MoE is faster.
"VLLM_MOE_GPTQ_EXLLAMA": lambda: (
os.getenv("VLLM_MOE_GPTQ_EXLLAMA", "true").lower() in ("true", "1")
os.getenv(
"VLLM_MOE_GPTQ_EXLLAMA",
"false" if _is_rdna_for_moe_default() else "true",
).lower()
in ("true", "1")
),
# Optional: enable external Oink custom ops (e.g., Blackwell RMSNorm).
# Disabled by default.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1691,6 +1691,8 @@ def __init__(
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
self.moe_mk: FusedMoEKernel | None = None
self.moe_quant_config: FusedMoEQuantConfig | None = None
# Extract properties from weight_quant
self.num_bits = weight_quant.num_bits
self.packed_factor = 32 // weight_quant.num_bits
Expand Down Expand Up @@ -1911,6 +1913,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
requires_grad=False,
)
layer.use_exllama_moe = False
self.moe_quant_config = self.get_fused_moe_quant_config(layer)

def _process_weights_awq_gemv(self, layer: torch.nn.Module) -> None:
"""AWQ GEMV MoE path: convert GPTQ [E, K/8, N] → AWQ [E, K, N/8]
Expand Down
Loading