Skip to content

Commit cfd608e

Browse files
committed
[ROCm][gpt-oss] Pass GateMode.INTERLEAVE for MXFP4 W4A16 fused MoE
The MXFP4 W4A16 weight-load path in oracle/mxfp4.py uses shuffle_weight_a16w4 (is_guinterleave=True), which interleaves gate/up columns within each weight tile. The CK/FlyDSL MoE kernels in aiter must be told this via gate_mode=GateMode.INTERLEAVE so they decode the gate/up packing correctly. Without the explicit gate_mode, aiter defaults to SEPARATED and (since ROCm/aiter#3123) dispatches the (SEPARATED + Swiglu + per_1x32 + fp4x2) case to a path that returns garbage for shuffled weights or crashes during CK2stages JIT for the unshuffled Quark variant (amd/gpt-oss-20b-w-mxfp4-a-bf16). This was the root cause of ROCM-25517 (gpt-oss-120b W4A16 gsm8k acc = 0) and ROCM-25478 (gpt-oss-20b Quark JIT crash). Other paths are unaffected: - FP8 W8A8 (DeepSeek-V4-Pro, DeepSeek-V3.2): shuffled with quark_ocp_mx.py:shuffle_weight(layout=(16,16)) — non-interleaved. use_mxfp4_w4a16 is False, default SEPARATED preserved. - MXFP4 W4A4 (amd/DeepSeek-R1-0528-MXFP4): shuffled via rocm_aiter_ops.shuffle_weights — non-interleaved. use_mxfp4_w4a16 is False, default SEPARATED preserved. The gate_mode kwarg was added to aiter.fused_moe in ROCm/aiter#3123 (aiter>=0.1.14). To stay compatible with older aiter shipping with vllm (e.g. aiter 0.1.13.post1 in the vllm-rocm:nightly image), we probe the aiter signature and drop the kwarg when unsupported — pre-vllm-project#3123 aiter tolerated the implicit SEPARATED default for interleave-shuffled weights, so dropping the kwarg is safe there. GateMode itself only exists on aiter>=0.1.14 and is imported under try/except for the same reason. Validation on MI355X (gfx950): vllm@main + aiter@main (6aeba41) openai/gpt-oss-120b W4A16 gsm8k: TP=1: 0.000 -> 0.905 TP=8: 0.000 -> 0.905 vllm@main + aiter@main amd/gpt-oss-20b-w-mxfp4-a-bf16 TP=2 enforce-eager: CK2stages JIT crash -> serves cleanly vllm-rocm:nightly + aiter 0.1.13.post1 openai/gpt-oss-120b W4A16 gsm8k: TP=1: 0.910 (backward-compat — gate_mode kwarg silently dropped) vllm-rocm:v0.22.0 + aiter@main openai/gpt-oss-120b W4A16 gsm8k: TP=1: 0.895 amd/gpt-oss120b-w-mxfp4-a-fp8 W4A8 (this PR composes with vllm-project#44804): TP=8 mc=1=326, mc=8=2087, mc=32=6523, mc=64=11610 tok/s Reference: sgl-project/sglang#25580 (sglang's equivalent fix). Recommended by aiter maintainer (XiaobingZhang) on ROCm/aiter#3586. Signed-off-by: Rohan Potdar <rohan.potdar@amd.com>
1 parent 2ed0a96 commit cfd608e

2 files changed

Lines changed: 37 additions & 0 deletions

File tree

vllm/_aiter_ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def _rocm_aiter_fused_moe_impl(
152152
output_dtype: torch.dtype | None = None,
153153
hidden_pad: int = 0,
154154
intermediate_pad: int = 0,
155+
gate_mode: str = "",
155156
bias1: torch.Tensor | None = None,
156157
bias2: torch.Tensor | None = None,
157158
moe_sorting_dispatch_policy: int = 0,
@@ -162,6 +163,11 @@ def _rocm_aiter_fused_moe_impl(
162163
activation = ActivationType(activation_method)
163164
quant_type = QuantType(quant_method)
164165

166+
extra_kwargs: dict = {}
167+
# `gate_mode` was added to aiter.fused_moe in #3123 (aiter>=0.1.14).
168+
if gate_mode and rocm_aiter_ops.fused_moe_supports_gate_mode():
169+
extra_kwargs["gate_mode"] = gate_mode
170+
165171
return fused_moe(
166172
hidden_states,
167173
w1,
@@ -183,6 +189,7 @@ def _rocm_aiter_fused_moe_impl(
183189
bias1=bias1,
184190
bias2=bias2,
185191
moe_sorting_dispatch_policy=moe_sorting_dispatch_policy,
192+
**extra_kwargs,
186193
)
187194

188195

@@ -204,6 +211,7 @@ def _rocm_aiter_fused_moe_fake(
204211
output_dtype: torch.dtype | None = None,
205212
hidden_pad: int = 0,
206213
intermediate_pad: int = 0,
214+
gate_mode: str = "",
207215
bias1: torch.Tensor | None = None,
208216
bias2: torch.Tensor | None = None,
209217
moe_sorting_dispatch_policy: int = 0,
@@ -1643,6 +1651,20 @@ def are_gdn_triton_kernels_available(cls) -> bool:
16431651
except (ImportError, ModuleNotFoundError):
16441652
return False
16451653

1654+
@classmethod
1655+
@if_aiter_supported
1656+
@functools.cache
1657+
def fused_moe_supports_gate_mode(cls) -> bool:
1658+
"""Probe whether the installed aiter.fused_moe accepts `gate_mode`.
1659+
1660+
Added in aiter#3123 (>=0.1.14). Builds with older aiter must omit the kwarg.
1661+
"""
1662+
import inspect
1663+
1664+
from aiter.fused_moe import fused_moe
1665+
1666+
return "gate_mode" in inspect.signature(fused_moe).parameters
1667+
16461668
@staticmethod
16471669
@if_aiter_supported
16481670
def register_ops_once() -> None:
@@ -1976,6 +1998,7 @@ def fused_moe(
19761998
output_dtype: torch.dtype | None = None,
19771999
hidden_pad: int = 0,
19782000
intermediate_pad: int = 0,
2001+
gate_mode: str = "",
19792002
bias1: torch.Tensor | None = None,
19802003
bias2: torch.Tensor | None = None,
19812004
moe_sorting_dispatch_policy: int = 0,
@@ -1998,6 +2021,7 @@ def fused_moe(
19982021
output_dtype,
19992022
hidden_pad,
20002023
intermediate_pad,
2024+
gate_mode,
20012025
bias1,
20022026
bias2,
20032027
moe_sorting_dispatch_policy,

vllm/model_executor/layers/fused_moe/experts/rocm_aiter_moe.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,18 @@ def rocm_aiter_fused_experts(
341341
- moe_config.intermediate_size_per_partition_unpadded
342342
)
343343

344+
# MXFP4 W4A16 weights are interleave-shuffled in oracle/mxfp4.py;
345+
# match with GateMode.INTERLEAVE or aiter#3123 dispatch returns
346+
# garbage / fails JIT.
347+
gate_mode = ""
348+
if quant_config.use_mxfp4_w4a16:
349+
try:
350+
from aiter.ops.flydsl.moe_common import GateMode
351+
352+
gate_mode = GateMode.INTERLEAVE.value
353+
except ImportError:
354+
pass
355+
344356
return rocm_aiter_ops.fused_moe(
345357
hidden_states,
346358
w1,
@@ -359,6 +371,7 @@ def rocm_aiter_fused_experts(
359371
output_dtype=output_dtype,
360372
hidden_pad=hidden_pad // 128 * 128,
361373
intermediate_pad=intermediate_pad // 64 * 64 * 2,
374+
gate_mode=gate_mode,
362375
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
363376
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
364377
moe_sorting_dispatch_policy=moe_sorting_dispatch_policy,

0 commit comments

Comments
 (0)