aiter/fused_moe.py (82 changes: 32 additions & 50 deletions)
@@ -18,7 +18,6 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.flydsl.utils import is_flydsl_available
from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd
import pyhip

BLOCK_SIZE_M = 32

@@ -143,56 +142,39 @@ def fused_moe(
bias2=None,
splitk=0,
):
# Fast path for small batches for Qwen3.5 397B FP8 PTPC (Only used for decoding phase: batch size 1~32)
if os.environ.get('AITER_MOE_SMALL_BATCH', '0') == '1' and hidden_states.shape[0] <= 32 and hidden_states.dtype == torch.bfloat16 and expert_mask is None and activation == ActivationType.Silu and \
((quant_type == QuantType.per_Token and w1.dtype == torch.float8_e4m3fnuz)):

from pyhip.contrib.moe import moe_gemm_batch1, moe_gemm_batch, moe_2stage_splitk
fp8_ptpc = ((quant_type == QuantType.per_Token and w1.dtype == torch.float8_e4m3fnuz))
B = hidden_states.shape[0]
E, N1, K1 = w1.shape
N2, K2 = w2.shape[1], w2.shape[2]
TOPK = topk_ids.shape[1]
#print("B=%d, E=%d, N1=%d, K1=%d, N2=%d, K2=%d, TOPK=%d"%(B, E, N1, K1, N2, K2, TOPK))
assert N1 == 2 * K2
gemm1_out = torch.empty([B, TOPK, N1 // 2], dtype=hidden_states.dtype, device=hidden_states.device)
# print(f"================================================= batch size {B} ========================================================")
if B == 1:
assert N1 == 2 * K2
# Skip moe_sorting for batch size 1
cur_out = torch.zeros([1, N2], dtype=hidden_states.dtype, device=hidden_states.device)
moe_gemm_batch1([N1 // 32, TOPK],[256], w1.dtype, True, hidden_states.data_ptr(), w1.data_ptr(), gemm1_out.data_ptr(), topk_ids.data_ptr(), topk_weight.data_ptr(), w1_scale.data_ptr() if w1_scale is not None else 0, 1, N1, K1)
moe_gemm_batch1([N2 // 32, TOPK],[64], w1.dtype, False, gemm1_out.data_ptr(), w2.data_ptr(), cur_out.data_ptr(), topk_ids.data_ptr(), topk_weight.data_ptr(), w2_scale.data_ptr() if w2_scale is not None else 0, 1, N2, K2)

return cur_out
else:
BLOCK_M = 16
sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, cur_out = moe_sorting(
topk_ids,
topk_weight,
E,
K1, # reduce dim is same with output dim
hidden_states.dtype,
BLOCK_M,
expert_mask, # None,
num_local_tokens,#None,
moe_sorting_dispatch_policy,
)
grid = sorted_expert_ids.shape[0]

if B * TOPK <= E:
grid = B * TOPK

moe_gemm_batch([N1 // 32, grid], [256],
w1.dtype, True,
hidden_states.data_ptr(), w1.data_ptr(), gemm1_out.data_ptr(), sorted_ids.data_ptr(), sorted_weights.data_ptr(), sorted_expert_ids.data_ptr(), num_valid_ids.data_ptr(), w1_scale.data_ptr() if w1_scale is not None else 0, B, N1, K1, TOPK)
BLOCK_TILE_SIZE_M = 16
BLOCK_TILE_SIZE_N = 64
moe_2stage_splitk([N2 // BLOCK_TILE_SIZE_N, grid], [64],
w1.dtype, TOPK, K2, N2, False, BLOCK_TILE_SIZE_M, BLOCK_TILE_SIZE_N,
gemm1_out.data_ptr(), w2.data_ptr(), cur_out.data_ptr(), sorted_ids.data_ptr(), sorted_weights.data_ptr(), sorted_expert_ids.data_ptr(), num_valid_ids.data_ptr(), w2_scale.data_ptr() if w2_scale is not None else 0, B, fp8_ptpc)
# Small-batch fast path for Qwen3.5 397B FP8 PTPC TP8 on gfx942 (decoding phase only: batch size 1~32)
# B=1~32, E=512, N1=256, K1=4096, N2=4096, K2=128, TOPK=10
if (os.environ.get("AITER_MOE_SMALL_BATCH", "0") == "1"
and 1 <= hidden_states.shape[0] <= 32
and hidden_states.dtype == torch.bfloat16
and expert_mask is None
and activation == ActivationType.Silu
and (quant_type == QuantType.per_Token and w1.dtype == torch.float8_e4m3fnuz)
and get_gfx() == "gfx942"
and topk_ids.shape[1] == 10
and w1.shape[1] == 256
and w1.shape[2] == 4096
and w2.shape[1] == 4096
and w2.shape[2] == 128
):
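# Opt-in only: this path runs when AITER_MOE_SMALL_BATCH=1 is set in the environment
# and every shape/dtype condition above holds; all other inputs take the default path below.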
from aiter.fused_moe_ptpc_fp8 import fused_moe_ptpc_fp8
moe_buf = fused_moe_ptpc_fp8(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
activation,
quant_type,
w1_scale,
w2_scale,
expert_mask,
num_local_tokens,
moe_sorting_dispatch_policy,
)

return cur_out
if moe_buf is not None:
return moe_buf
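# fused_moe_ptpc_fp8 returns None when its guards fail or the prebuilt HSACO kernels
# cannot be loaded or launched; in that case fall through to the default path below.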

if not block_size_M:
block_size_M = -1
aiter/fused_moe_ptpc_fp8.py (256 changes: 256 additions & 0 deletions)
@@ -0,0 +1,256 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

"""Qwen3.5 397B MoE PTPC FP8 TP8 run with prebuilt hasco artifacts on gfx942:
``fused_moe()`` calls ``fused_moe_ptpc_fp8`` only when the following conditions are met:
- B=1~32, E=512, N1=256, K1=4096, N2=4096, K2=128, TOPK=10
- if (os.environ.get("AITER_MOE_SMALL_BATCH", "0") == "1"
and 1 <= hidden_states.shape[0] <= 32
and hidden_states.dtype == torch.bfloat16
and expert_mask is None
and activation == ActivationType.Silu
and (quant_type == QuantType.per_Token and w1.dtype == torch.float8_e4m3fnuz)
and get_gfx() == "gfx942"
and topk_ids.shape[1] == 10
and w1.shape[1] == 256
and w1.shape[2] == 4096
and w2.shape[1] == 4096
and w2.shape[2] == 128)
- Requires matching artifacts under ``hsa/gfx942/fmoe_ptpc_fp8/``, and loading via ``csrc.cpp_itfs.hsaco_tools.get_kernel``.
"""

import os
from typing import Any, Optional

import torch

import aiter
from aiter import ActivationType, QuantType, dtypes, logger
from aiter.jit.utils.chip_info import get_gfx
from aiter.fused_moe import moe_sorting
from csrc.cpp_itfs.hsaco_tools import get_kernel
from csrc.cpp_itfs.utils import AITER_CORE_DIR


def fused_moe_ptpc_fp8(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
activation: ActivationType,
quant_type: QuantType,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
expert_mask: Any,
num_local_tokens: Any,
moe_sorting_dispatch_policy: int,
) -> Optional[torch.Tensor]:
B = int(hidden_states.shape[0])
if not (1 <= B <= 32):
return None
if (
hidden_states.dtype != torch.bfloat16
or expert_mask is not None
or activation != ActivationType.Silu
):
return None
if not (
quant_type == QuantType.per_Token
and w1.dtype == torch.float8_e4m3fnuz
and get_gfx() == "gfx942"
):
return None

E, N1, K1 = w1.shape
N2, K2 = w2.shape[1], w2.shape[2]
TOPK = topk_ids.shape[1]
fp8_ptpc = w1.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
num_CU = torch.cuda.get_device_properties(hidden_states.device).multi_processor_count
assert N1 == 2 * K2
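# N1 packs the gate and up projections (hence N1 == 2 * K2), so the SiLU-fused first
# GEMM emits N1 // 2 columns per selected expert, matching the down GEMM's K dimension.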

topk_w_f32 = (
topk_weight
if topk_weight.dtype == torch.float32
else topk_weight.float()
)

gemm1_out = torch.empty(
[B, TOPK, N1 // 2],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
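# B == 1: skip moe_sorting entirely; the batch-1 kernels consume topk_ids and
# topk_w_f32 directly (one launch for the SiLU-fused gate GEMM, one for the down GEMM).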
if B == 1:
assert N1 == 2 * K2
try:
moe_gemm_batch1_gate = get_kernel(
f"{AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
f"moe_gemm_batch1-1-weight_dtype=torch.float8_e4m3fnuz-with_silu=True:moe_gemm_batch1"
)
moe_gemm_batch1_down = get_kernel(
f"{AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
f"moe_gemm_batch1-1-weight_dtype=torch.float8_e4m3fnuz-with_silu=False:moe_gemm_batch1"
)
cur_out = torch.zeros(
[1, N2], dtype=hidden_states.dtype, device=hidden_states.device
)
moe_gemm_batch1_gate(
[N1 // 32, TOPK],
[256],
hidden_states,
w1,
gemm1_out,
topk_ids,
topk_w_f32,
w1_scale,
1,
N1,
K1,
)
moe_gemm_batch1_down(
[N2 // 32, TOPK],
[64],
gemm1_out,
w2,
cur_out,
topk_ids,
topk_w_f32,
w2_scale,
1,
N2,
K2,
)
except Exception as e:
msg = (
f"fused_moe_ptpc_fp8 (B=1): HSACO kernel load or launch failed: {e}. "
f"Check artifacts under {AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
)
logger.warning(
msg + "; fallback to default fused_moe_() instead."
)
return None
elif 2 <= B <= 32:
# Stage 1: shared ``moe_sorting`` + ``moe_gemm_batch``.
# Stage 2: choose between ``moe_2stage_down_loopn`` and ``moe_2stage_splitk`` based on the ``use_down_loopn`` condition.
BLOCK_M = 16
sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, cur_out = moe_sorting(
topk_ids,
topk_weight,
E,
K1,
hidden_states.dtype,
BLOCK_M,
expert_mask,
num_local_tokens,
moe_sorting_dispatch_policy,
)
grid = int(sorted_expert_ids.shape[0])
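# At most B * TOPK (token, expert) pairs exist, so no more than B * TOPK sorted blocks
# can be non-empty; cap the launch grid accordingly to avoid idle workgroups.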
if B * TOPK <= E:
grid = B * TOPK

try:
moe_gemm_batch = get_kernel(
f"{AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
f"moe_gemm_batch-1-weight_dtype=torch.float8_e4m3fnuz-with_silu=True:moe_gemm_batch"
)
moe_gemm_batch(
[N1 // 32, grid],
[256],
hidden_states,
w1,
gemm1_out,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
w1_scale,
B,
N1,
K1,
TOPK,
)
except Exception as e:
msg = (
f"fused_moe_ptpc_fp8 (B={B}): moe_gemm_batch HSACO kernel load or launch failed: {e}. "
f"Check artifacts under {AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
)
logger.warning(
msg + "; fallback to default fused_moe_() instead."
)
return None


BLOCK_N = 1024
use_down_loopn = (
fp8_ptpc
and (N2 // BLOCK_N) * grid >= num_CU
and N2 % BLOCK_N == 0
and 16 <= B <= 32
)
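# Prefer the loop-N down kernel only when the [N2 // BLOCK_N, grid] launch is large
# enough to occupy every CU and the batch is large (16 <= B <= 32); otherwise use the
# split-K down kernel below.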

if use_down_loopn:
try:
moe_2stage_down_loopn = get_kernel(
f"{AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
f"moe_2stage_down_loopn-1-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-"
f"BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=16-fp8_ptpc=True-BLOCK_N=1024-"
f"atomic_write=False-STAGES=3:moe_2stage_down_loopn"
)
gemm2_out = torch.empty(
[B, TOPK, N2],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
moe_2stage_down_loopn(
[N2 // BLOCK_N, grid],
[256],
gemm1_out,
w2,
gemm2_out,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
w2_scale,
B,
)
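# The loop-N kernel produces one partial row per (token, top-k expert) pair; reduce
# over the TOPK dimension to form the final [B, N2] output.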
cur_out = torch.sum(gemm2_out, dim=1)
except Exception as e:
msg = (
f"fused_moe_ptpc_fp8 (B={B}): moe_2stage_down_loopn HSACO kernel load or launch failed: {e}. "
f"Check artifacts under {AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
)
logger.warning(
msg + "; fallback to default fused_moe_() instead."
)
return None
else:
try:
moe_2stage_splitk = get_kernel(
f"{AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
f"moe_2stage_splitk-1-weight_dtype=torch.float8_e4m3fnuz-TOPK=10-K=128-N=4096-"
f"with_silu=False-BLOCK_TILE_SIZE_M=16-BLOCK_TILE_SIZE_N=64-fp8_ptpc=True:moe_2stage_splitk"
)
BLOCK_TILE_SIZE_N = 64
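# The split-K down kernel writes its result directly into cur_out allocated by
# moe_sorting, so no separate TOPK reduction is needed on this path.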
moe_2stage_splitk(
[N2 // BLOCK_TILE_SIZE_N, grid],
[64],
gemm1_out,
w2,
cur_out,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
w2_scale,
B,
)
except Exception as e:
msg = (
f"fused_moe_ptpc_fp8 (B={B}): moe_2stage_splitk HSACO kernel load or launch failed: {e}. "
f"Check artifacts under {AITER_CORE_DIR}/hsa/gfx942/fmoe_ptpc_fp8/"
)
logger.warning(
msg + "; fallback to default fused_moe_() instead."
)
return None

return cur_out