Skip to content

Commit ab79b0a

Browse files
committed
fix review & fix moe attribute error
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
1 parent 846bcde commit ab79b0a

3 files changed

Lines changed: 11 additions & 9 deletions

File tree

vllm/model_executor/layers/fused_moe/oracle/mxfp4.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ class Mxfp4MoeBackend(Enum):
6565
MARLIN = "MARLIN"
6666
# ROCm AITER backends
6767
AITER_MXFP4_BF16 = "AITER_MXFP4_BF16" # W4A16: CK kernel
68+
# Keep the legacy name as an alias while the ROCm split backend rename settles.
69+
AITER = "AITER_MXFP4_BF16"
6870
AITER_MXFP4_FP8 = "AITER_MXFP4_FP8" # W4A8: triton kernel
6971
# Triton
7072
TRITON = "TRITON"
@@ -255,7 +257,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
255257
backend-level ``is_supported_config`` check filters by device capability).
256258
"""
257259
if current_platform.is_rocm():
258-
return [Mxfp4MoeBackend.AITER]
260+
return [Mxfp4MoeBackend.AITER_MXFP4_BF16]
259261
_AVAILABLE_BACKENDS = [
260262
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
261263
Mxfp4MoeBackend.DEEPGEMM_MXFP4,
@@ -555,7 +557,7 @@ def _return_or_raise(
555557
):
556558
priority_backends = [
557559
Mxfp4MoeBackend.TRITON_UNFUSED,
558-
Mxfp4MoeBackend.AITER,
560+
Mxfp4MoeBackend.AITER_MXFP4_BF16,
559561
]
560562
else:
561563
priority_backends = _get_priority_backends()
@@ -1269,7 +1271,7 @@ def convert_weight_to_mxfp4_moe_kernel_format(
12691271
w2_bias,
12701272
)
12711273

1272-
elif mxfp4_backend == Mxfp4MoeBackend.AITER:
1274+
elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_BF16:
12731275
from vllm._aiter_ops import rocm_aiter_ops
12741276

12751277
if w13_bias is not None:

vllm/utils/deep_gemm.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -473,11 +473,7 @@ def tf32_hc_prenorm_gemm(
473473
"""
474474
_lazy_init()
475475
if _tf32_hc_prenorm_gemm_impl is None:
476-
out.zero_()
477-
sqrsum.zero_()
478-
out[0].copy_(torch.matmul(x.to(torch.float32), fn.t().to(torch.float32)))
479-
sqrsum[0].copy_(x.to(torch.float32).square().sum(dim=-1))
480-
return out
476+
return _missing()
481477
return _tf32_hc_prenorm_gemm_impl(
482478
x,
483479
fn,

vllm/v1/attention/ops/rocm_aiter_mla_sparse.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010

1111
from vllm.forward_context import get_forward_context
1212
from vllm.platforms import current_platform
13-
from vllm.platforms.rocm import _ON_GFX942
1413
from vllm.triton_utils import tl, triton
1514
from vllm.utils.torch_utils import LayerNameType
1615
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
1716
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
1817

18+
if current_platform.is_rocm():
19+
from vllm.platforms.rocm import _ON_GFX942
20+
else:
21+
_ON_GFX942 = False
22+
1923

2024
@triton.jit
2125
def _indexer_k_quant_and_cache_kernel(

0 commit comments

Comments
 (0)