Commit 0b992ff

Add hybrid MoE kernel with wvSplitK int4 GEMM (#876)

* Add on-device fused MoE kernel for wvSplitK int4 GEMM

Adds fused_moe_wvSplitK_int4_gemm that dispatches expert blocks via
blockIdx.y on-device, eliminating host-side loops and GPU-CPU sync.
Weights are in skinny layout [E, N, K//8] int32 (ExLlama shuffle).

Key optimizations for RDNA 3.5 decode (batch=1):
- Use all CUs per expert block for maximum bandwidth
- YTILE=2 for N=1 decode (better occupancy than YTILE=1 or 4)
- Reduced LDS allocation (16KB vs 64KB) for higher occupancy
- Non-temporal weight loads to avoid L1 pollution
- Scattered mode with sorted_token_ids for decode without pre-permutation

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

* Add HybridW4A16MoEExperts with Triton prefill + HIP decode

Dispatch MoE INT4 GEMM based on batch size: Triton for prefill (M>5),
HIP wvSplitK for decode (M<=5). Both read from the same shuffle-packed
[E, N, K//8] int32 weights, with no duplication.

The Triton path adds use_shuffle_w4a16 to fused_moe_kernel_gptq_awq,
which unpacks ExLlama-shuffled int32 via tl.interleave, then extracts
nibbles with shift+mask. Scales are [E, N, K//G], symmetric only.

Weight processing converts GPTQ [E, K/8, N] to skinny [E, N, K//8] with
ExLlama shuffle packing at load time.

Enabled by default on ROCm via VLLM_MOE_HYBRID_W4A16=true.

Qwen3-Omni-30B-A3B AWQ on Strix Halo (vs exllama baseline):
  TPOT: 14.51ms → 13.73ms (-5.4%)
  TTFT: 996ms → 841ms (-15.6%)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

* Migrate hybrid W4A16 MoE to internal MK interface

Build self.moe_kernel directly in process_weights_after_loading via
maybe_make_prepare_finalize(allow_new_interface=True) so
HybridW4A16MoEExperts runs on single-GPU deployments (no DP/EP), where
the legacy select_gemm_impl path is bypassed by the upstream MoE refactor.

Route apply() through self.moe_kernel.apply when the hybrid path is
active; the legacy fused_experts call is preserved as the non-hybrid
fallback. The dead select_gemm_impl branch for hybrid is removed.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

* Fix test: pass shared_experts=None to FusedMoEKernelModularImpl

Base branch added a required shared_experts parameter to
FusedMoEKernelModularImpl.__init__().

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

---------

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
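
The batch-size dispatch described in the commit message can be sketched as follows. This is only an illustration under stated assumptions, not code from this commit: `select_moe_path` and the returned labels are hypothetical, and only the threshold of 5 and the attribute name `MAX_SKINNY_BATCH_SIZE` are taken from the commit message and the test file.

```python
# Hypothetical sketch of the hybrid dispatch rule; not the actual vLLM code.
MAX_SKINNY_BATCH_SIZE = 5  # threshold from the commit message (M <= 5 -> HIP)


def select_moe_path(num_tokens: int, threshold: int = MAX_SKINNY_BATCH_SIZE) -> str:
    """Return a label for the kernel family that would handle `num_tokens` rows."""
    if num_tokens <= threshold:
        return "hip_wvsplitk"  # decode-sized batches use the HIP skinny GEMM
    return "triton"  # prefill-sized batches use the Triton fused MoE kernel
```

Passing `threshold=0` sends every batch to the Triton label and a very large threshold sends every batch to the HIP label, which is the same trick the tests use to force one path.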
1 parent 1e7b7be commit 0b992ff

10 files changed

Lines changed: 1420 additions & 142 deletions

File tree

.github/workflows/build-rocm-wheels.yml

Lines changed: 1 addition & 0 deletions
@@ -227,6 +227,7 @@ jobs:
           run: |
             python -m pytest -v --timeout=300 \
               tests/kernels/moe/test_exllama_moe.py \
+              tests/kernels/moe/test_hybrid_w4a16_moe.py \
               tests/kernels/quantization/test_awq_gemv_moe.py \
               tests/kernels/quantization/test_hip_w4a16.py \
               tests/kernels/quantization/test_hybrid_w4a16_triton.py \

csrc/rocm/ops.h

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,14 @@ torch::Tensor wvSplitK_int4_g(const at::Tensor& in_a, const at::Tensor& in_b,
                               const std::optional<at::Tensor>& in_bias,
                               const int64_t CuCount, const int64_t group_size);
 
+void fused_moe_wvSplitK_int4_gemm(torch::Tensor a, torch::Tensor w,
+                                  torch::Tensor scales, torch::Tensor c,
+                                  torch::Tensor expert_ids,
+                                  int64_t block_size_m, int64_t CuCount,
+                                  int64_t group_size, torch::Tensor zero_points,
+                                  torch::Tensor sorted_token_ids,
+                                  int64_t top_k);
+
 #ifdef VLLM_SKINNY_GEMM_SWEEP
 torch::Tensor wvSplitK_sweep(const at::Tensor& in_a, const at::Tensor& in_b,
                              const std::optional<at::Tensor>& in_bias,

csrc/rocm/skinny_gemms_int4.cu

Lines changed: 398 additions & 101 deletions
Large diffs are not rendered by default.

csrc/rocm/torch_bindings.cpp

Lines changed: 9 additions & 0 deletions
@@ -48,6 +48,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       "int group_size) -> Tensor");
   rocm_ops.impl("wvSplitK_int4_g", torch::kCUDA, &wvSplitK_int4_g);
 
+  // Fused MoE wvSplitK int4 GEMM: expert blocks are dispatched on-device
+  // via blockIdx.y (no host-side loop over experts)
+  rocm_ops.def(
+      "fused_moe_wvSplitK_int4_gemm(Tensor a, Tensor w, Tensor scales, "
+      "Tensor c, Tensor expert_ids, int block_size_m, int CuCount, "
+      "int group_size, Tensor zero_points, Tensor sorted_token_ids, "
+      "int top_k) -> ()");
+  rocm_ops.impl("fused_moe_wvSplitK_int4_gemm", torch::kCUDA,
+                &fused_moe_wvSplitK_int4_gemm);
+
 #ifdef VLLM_SKINNY_GEMM_SWEEP
   rocm_ops.def(
       "wvSplitK_int8_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, "

tests/kernels/moe/test_hybrid_w4a16_moe.py

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for HybridW4A16MoEExperts (Triton prefill + HIP decode).
+
+Validates the hybrid MoE kernel by:
+1. Creating random fp16 MoE weights
+2. Quantizing them to symmetric 4-bit with group_size=32 or 128
+3. Packing into ExLlama shuffle format [E, N, K//8] int32
+4. Running HybridW4A16MoEExperts via FusedMoEKernelModularImpl
+5. Comparing against torch_experts reference using dequantized weights
+
+Tests exercise both paths:
+- Decode (M<=5): HIP wvSplitK_int4 kernel
+- Prefill (M>5): Triton fused_moe kernel with use_shuffle_w4a16
+"""
+
+import pytest
+import torch
+
+from tests.kernels.moe.utils import make_dummy_moe_config
+from tests.kernels.utils import torch_experts
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.kernels.linear.mixed_precision.hybrid_w4a16 import (
+    pack_int4_exllama_shuffle,
+)
+from vllm.model_executor.layers.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.fused_moe.config import (
+    int4_w4a16_moe_quant_config,
+)
+from vllm.model_executor.layers.fused_moe.hybrid_w4a16_moe import (
+    HybridW4A16MoEExperts,
+)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEKernelModularImpl,
+)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoDPEPModular,
+)
+from vllm.platforms import current_platform
+from vllm.v1.worker.workspace import init_workspace_manager
+
+NUM_BITS = 4
+PACK_FACTOR = 32 // NUM_BITS  # 8 nibbles per int32
+
+
+def _symmetric_quantize_4bit_skinny(
+    w: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Symmetric 4-bit quantization → skinny ExLlama format.
+
+    Input: w [K, N] fp16
+    Returns:
+        q_skinny: [N, K//8] int32 (ExLlama shuffle packed)
+        scales: [N, K//G] fp16 (skinny layout)
+        w_ref: [K, N] fp16 (dequantized reference)
+    """
+    K, N = w.shape
+    assert K % group_size == 0
+    num_groups = K // group_size
+
+    w_grouped = w.reshape(num_groups, group_size, N)
+    abs_max = w_grouped.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+    scales = abs_max / 7.0
+
+    # Quantize to unsigned [0, 15] with zero_point = 8
+    w_q = torch.round(w_grouped / scales).clamp(-7, 7).int() + 8
+    w_q = w_q.reshape(K, N)
+
+    # Dequantized reference
+    w_ref = (
+        ((w_q.float() - 8.0).reshape(num_groups, group_size, N) * scales)
+        .reshape(K, N)
+        .half()
+    )
+
+    # Pack into ExLlama shuffle: transpose to [N, K], pack to [N, K//8]
+    w_q_uint4 = w_q.to(torch.uint8)  # values in [0, 15]
+    w_q_t = w_q_uint4.t().contiguous()  # [N, K]
+    q_skinny = pack_int4_exllama_shuffle(w_q_t)  # [N, K//8] int32
+
+    # Scales: [num_groups, N] → [N, num_groups] (skinny layout)
+    scales_skinny = scales.squeeze(1).t().contiguous()  # [N, K//G]
+
+    return q_skinny, scales_skinny, w_ref
+
+
+def _make_hybrid_moe_weights(
+    E: int,
+    K: int,
+    N: int,
+    group_size: int,
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Create fake skinny-packed MoE weights for E experts.
+
+    Returns (w_skinny, scales, w_ref) where:
+    - w_skinny: [E, N, K//8] int32 (ExLlama shuffle packed)
+    - scales: [E, N, K//G] fp16 (skinny layout)
+    - w_ref: [E, N, K] fp16 (torch_experts convention)
+    """
+    all_skinny = []
+    all_scales = []
+    all_ref = []
+
+    for _ in range(E):
+        w_fp = torch.randn(K, N, device=device, dtype=torch.float16) / 10.0
+        q_skinny, scales, w_ref = _symmetric_quantize_4bit_skinny(w_fp, group_size)
+        all_skinny.append(q_skinny)
+        all_scales.append(scales)
+        all_ref.append(w_ref.t())  # transpose to [N, K] for torch_experts
+
+    w_skinny = torch.stack(all_skinny)  # [E, N, K//8]
+    w_scales = torch.stack(all_scales)  # [E, N, K//G]
+    w_ref = torch.stack(all_ref)  # [E, N, K]
+
+    return w_skinny, w_scales, w_ref
+
+
+def _run_hybrid_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    force_triton: bool = False,
+    force_hip: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Build weights, run HybridW4A16MoEExperts and torch_experts reference.
+
+    Args:
+        force_triton: Force the Triton prefill path for all batch sizes.
+        force_hip: Force the HIP wvSplitK path for all batch sizes.
+
+    Returns (hybrid_output, reference_output).
+    """
+    torch.cuda.manual_seed(1)
+    device = torch.device("cuda")
+
+    assert k % group_size == 0
+
+    # w1: gate+up projection [E, 2*N, K//8], ref [E, 2*N, K]
+    w1_skinny, w1_scales, w1_ref = _make_hybrid_moe_weights(
+        e, k, 2 * n, group_size, device
+    )
+    # w2: down projection [E, K, N//8], ref [E, K, N]
+    w2_skinny, w2_scales, w2_ref = _make_hybrid_moe_weights(e, n, k, group_size, device)
+
+    hidden = torch.randn(m, k, device=device, dtype=torch.float16) / 10
+    scores = torch.randn(m, e, device=device, dtype=torch.float16)
+
+    topk_weights, topk_ids, _ = fused_topk(hidden, scores, topk, False)
+
+    quant_config = int4_w4a16_moe_quant_config(
+        w1_scale=w1_scales,
+        w2_scale=w2_scales,
+        w1_zp=None,
+        w2_zp=None,
+        block_shape=[0, group_size],
+    )
+
+    moe_config = make_dummy_moe_config(
+        num_experts=e,
+        experts_per_token=topk,
+        hidden_dim=k,
+        intermediate_size_per_partition=n,
+        in_dtype=torch.float16,
+    )
+
+    experts = HybridW4A16MoEExperts(
+        moe_config=moe_config,
+        quant_config=quant_config,
+    )
+
+    orig_threshold = HybridW4A16MoEExperts.MAX_SKINNY_BATCH_SIZE
+    if force_triton:
+        HybridW4A16MoEExperts.MAX_SKINNY_BATCH_SIZE = 0
+    elif force_hip:
+        HybridW4A16MoEExperts.MAX_SKINNY_BATCH_SIZE = 10000
+
+    try:
+        mk = FusedMoEKernelModularImpl(
+            fused_experts=experts,
+            prepare_finalize=MoEPrepareAndFinalizeNoDPEPModular(),
+            shared_experts=None,
+        )
+
+        init_workspace_manager(device)
+        vllm_config = VllmConfig()
+        with set_current_vllm_config(vllm_config):
+            torch_output = torch_experts(
+                hidden,
+                w1_ref,
+                w2_ref,
+                topk_weight=topk_weights,
+                topk_ids=topk_ids,
+                global_num_experts=e,
+            )
+
+            hybrid_out = mk.apply(
+                hidden_states=hidden,
+                w1=w1_skinny,
+                w2=w2_skinny,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                global_num_experts=e,
+                expert_map=None,
+                activation=MoEActivation.SILU,
+                apply_router_weight_on_input=False,
+            )
+    finally:
+        HybridW4A16MoEExperts.MAX_SKINNY_BATCH_SIZE = orig_threshold
+
+    return hybrid_out, torch_output
+
+
+@pytest.mark.skipif(
+    not current_platform.is_rocm(),
+    reason="HybridW4A16MoEExperts requires ROCm",
+)
+@pytest.mark.parametrize("m", [1, 4, 16, 64])
+@pytest.mark.parametrize("n,k", [(256, 256), (512, 256)])
+@pytest.mark.parametrize("e,topk", [(8, 2), (16, 4)])
+@pytest.mark.parametrize("group_size", [32, 128])
+def test_hybrid_w4a16_moe(m: int, n: int, k: int, e: int, topk: int, group_size: int):
+    """Test natural dispatch: HIP for decode (m<=5), Triton for prefill (m>5)."""
+    hybrid_out, torch_output = _run_hybrid_moe(m, n, k, e, topk, group_size)
+    torch.testing.assert_close(hybrid_out, torch_output, atol=2e-2, rtol=0)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_rocm(),
+    reason="HybridW4A16MoEExperts requires ROCm",
+)
+@pytest.mark.parametrize("m", [1, 4, 16])
+@pytest.mark.parametrize("n,k", [(256, 256)])
+@pytest.mark.parametrize("e,topk", [(8, 2)])
+@pytest.mark.parametrize("group_size", [32])
+def test_hybrid_w4a16_moe_force_triton(
+    m: int, n: int, k: int, e: int, topk: int, group_size: int
+):
+    """Force the Triton path for all batch sizes (including m=1)."""
+    hybrid_out, torch_output = _run_hybrid_moe(
+        m, n, k, e, topk, group_size, force_triton=True
+    )
+    torch.testing.assert_close(hybrid_out, torch_output, atol=2e-2, rtol=0)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_rocm(),
+    reason="HybridW4A16MoEExperts requires ROCm",
+)
+@pytest.mark.parametrize("m", [1, 16, 64])
+@pytest.mark.parametrize("n,k", [(256, 256)])
+@pytest.mark.parametrize("e,topk", [(8, 2)])
+@pytest.mark.parametrize("group_size", [32])
+def test_hybrid_w4a16_moe_force_hip(
+    m: int, n: int, k: int, e: int, topk: int, group_size: int
+):
+    """Force the HIP wvSplitK path for all batch sizes (including m=64)."""
+    hybrid_out, torch_output = _run_hybrid_moe(
+        m, n, k, e, topk, group_size, force_hip=True
+    )
+    torch.testing.assert_close(hybrid_out, torch_output, atol=2e-2, rtol=0)
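
The quantization recipe used by the test helper (scale = abs_max / 7, codes clamped to [-7, 7] and offset by 8, eight codes per 32-bit word) can be followed on a single group in plain Python. This is a sketch under stated assumptions: `quantize_group`, `dequantize_group`, `pack_linear` and `unpack_linear` are hypothetical names, and the packing here is a simple low-nibble-first order chosen for illustration, which is not the ExLlama shuffle order produced by `pack_int4_exllama_shuffle`.

```python
# Illustrative pure-Python version of the test's symmetric 4-bit recipe.
# The linear nibble order below is an assumption for illustration only;
# it is NOT the ExLlama shuffle layout used by the real kernels.


def quantize_group(values: list[float]) -> tuple[list[int], float]:
    """Quantize one group to unsigned 4-bit codes with zero point 8."""
    abs_max = max(max(abs(v) for v in values), 1e-5)  # same floor as the test
    scale = abs_max / 7.0
    # round() rounds half to even, matching torch.round's tie behaviour
    codes = [max(-7, min(7, round(v / scale))) + 8 for v in values]
    return codes, scale


def dequantize_group(codes: list[int], scale: float) -> list[float]:
    """Invert quantize_group: subtract the zero point, multiply by the scale."""
    return [(c - 8) * scale for c in codes]


def pack_linear(codes: list[int]) -> int:
    """Pack eight 4-bit codes into one 32-bit word, code i in bits 4i..4i+3."""
    assert len(codes) == 8 and all(0 <= c <= 15 for c in codes)
    word = 0
    for i, c in enumerate(codes):
        word |= c << (4 * i)
    return word


def unpack_linear(word: int) -> list[int]:
    """Extract the eight nibbles again with shift + mask."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]
```

Because the clamp range is [-7, 7], the stored codes only ever use 1..15, and the dequantized value differs from the input by at most half a scale step. That bounded error is why the tests compare against dequantized reference weights rather than the original fp16 weights.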

vllm/_custom_ops.py

Lines changed: 32 additions & 0 deletions
@@ -2713,6 +2713,38 @@ def _wvSplitK_int4_g_fake(
     return torch.empty((N, M), dtype=in_b.dtype, device=in_b.device)
 
 
+def fused_moe_wvSplitK_int4_gemm(
+    a: torch.Tensor,
+    w: torch.Tensor,
+    scales: torch.Tensor,
+    c: torch.Tensor,
+    expert_ids: torch.Tensor,
+    block_size_m: int,
+    cu_count: int,
+    group_size: int,
+    zero_points: torch.Tensor | None = None,
+    sorted_token_ids: torch.Tensor | None = None,
+    top_k: int = 1,
+) -> None:
+    if zero_points is None:
+        zero_points = torch.empty(0, dtype=scales.dtype, device=a.device)
+    if sorted_token_ids is None:
+        sorted_token_ids = torch.empty(0, dtype=torch.int32, device=a.device)
+    torch.ops._rocm_C.fused_moe_wvSplitK_int4_gemm(
+        a,
+        w,
+        scales,
+        c,
+        expert_ids,
+        block_size_m,
+        cu_count,
+        group_size,
+        zero_points,
+        sorted_token_ids,
+        top_k,
+    )
+
+
 def wvSplitK_int4g_sweep(
     a: torch.Tensor,
     b: torch.Tensor,

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -113,6 +113,7 @@
     VLLM_USE_OINK_OPS: bool = False
     VLLM_MOE_AWQ_GEMV_HIP: bool = False
     VLLM_MOE_GPTQ_EXLLAMA: bool = False
+    VLLM_MOE_HYBRID_W4A16: bool = True
     VLLM_ROCM_USE_MOE_WNA16_CUDA_KERNEL: bool = False
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
@@ -1003,6 +1004,11 @@ def _get_or_set_default() -> str:
     "VLLM_MOE_AWQ_GEMV_HIP": lambda: (
         os.getenv("VLLM_MOE_AWQ_GEMV_HIP", "false").lower() in ("true", "1")
     ),
+    # Use hybrid W4A16 (HIP skinny + Triton) kernel for MoE on ROCm.
+    # Converts weights to skinny layout [E, N, K//8] int32 (ExLlama shuffle).
+    "VLLM_MOE_HYBRID_W4A16": lambda: (
+        os.getenv("VLLM_MOE_HYBRID_W4A16", "true").lower() in ("true", "1")
+    ),
     # Use exllama 4-bit kernel for MoE GPTQ instead of Triton.
     # Requires exllama-native weight format [E, K/8, N] int32.
     "VLLM_MOE_GPTQ_EXLLAMA": lambda: (
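
The parsing expression for `VLLM_MOE_HYBRID_W4A16` treats only `true` and `1` (case-insensitive) as enabled and defaults to enabled when the variable is unset. A small standalone sketch of that behaviour, where `moe_hybrid_w4a16_enabled` is a hypothetical helper that takes a mapping in place of the process environment so it can be exercised without mutating `os.environ`:

```python
import os
from collections.abc import Mapping


def moe_hybrid_w4a16_enabled(environ: Mapping[str, str] = os.environ) -> bool:
    """Evaluate the same expression as the diff against a given mapping."""
    # Default is "true": the hybrid path is on unless explicitly disabled.
    return environ.get("VLLM_MOE_HYBRID_W4A16", "true").lower() in ("true", "1")
```

One consequence worth noting: values such as `yes` or `on` are not recognised and therefore disable the path, so `VLLM_MOE_HYBRID_W4A16=0` or `=false` are the unambiguous ways to opt out.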
