102 changes: 65 additions & 37 deletions aiter/fused_moe.py
@@ -18,6 +18,7 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.flydsl.utils import is_flydsl_available
from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd
from aiter.fused_moe_pyhip_hsaco import MOE_HSACO_SMALL_BATCH_MAX_B

BLOCK_SIZE_M = 32

@@ -142,44 +143,32 @@ def fused_moe(
bias2=None,
splitk=0,
):
# fast path for small batches
if (
    os.environ.get('AITER_MOE_SMALL_BATCH', '0') == '1'
    and hidden_states.shape[0] <= 16
    and hidden_states.dtype == torch.bfloat16
    and expert_mask is None
    and activation == ActivationType.Silu
    and (
        (quant_type == QuantType.No and w1.dtype == torch.bfloat16)
        or (quant_type == QuantType.per_Token and w1.dtype == torch.float8_e4m3fnuz)
    )
):
B = hidden_states.shape[0]
E, N1, K1 = w1.shape
N2, K2 = w2.shape[1], w2.shape[2]
TOPK = topk_ids.shape[1]
assert N1 == 2 * K2
gemm1_out = torch.empty([B, TOPK, N1 // 2], dtype=hidden_states.dtype, device=hidden_states.device)
if B == 1:
assert N1 == 2 * K2
gemm2_out = torch.zeros([1, N2], dtype=hidden_states.dtype, device=hidden_states.device)
pyhip.kernels.moe.moe_gemm_batch1(
    [N1 // 32, TOPK], [256], w1.dtype, K1, N1, True,
    hidden_states.data_ptr(), w1.data_ptr(), gemm1_out.data_ptr(),
    topk_ids.data_ptr(), topk_weight.data_ptr(),
    w1_scale.data_ptr() if w1_scale is not None else 0, 1)
pyhip.kernels.moe.moe_gemm_batch1(
    [N2 // 32, TOPK], [64], w1.dtype, K2, N2, False,
    gemm1_out.data_ptr(), w2.data_ptr(), gemm2_out.data_ptr(),
    topk_ids.data_ptr(), topk_weight.data_ptr(),
    w2_scale.data_ptr() if w2_scale is not None else 0, 1)
return gemm2_out
else:
BLOCK_M = 16
sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting(
topk_ids,
topk_weight,
E,
K1,  # reduction dim is the same as the output dim
hidden_states.dtype,
BLOCK_M,
expert_mask,
num_local_tokens,
moe_sorting_dispatch_policy,
)
###### ONLY ALLOWED TO CHANGE THIS PART, STARTING FROM HERE ######
# HSACO small-batch fast path: prebuilt ``.co`` kernels exist only for ``1 <= B <= MOE_HSACO_SMALL_BATCH_MAX_B``.
if (
os.environ.get("AITER_MOE_SMALL_BATCH", "0") == "1"
and 1 <= int(hidden_states.shape[0]) <= MOE_HSACO_SMALL_BATCH_MAX_B
):
from aiter.fused_moe_pyhip_hsaco import run_moe_small_batch_hsaco

pyhip.kernels.moe.moe_gemm_batch(
    [N1 // 32, sorted_expert_ids.shape[0]], [256], w1.dtype, TOPK, K1, N1, True,
    hidden_states.data_ptr(), w1.data_ptr(), gemm1_out.data_ptr(),
    sorted_ids.data_ptr(), sorted_weights.data_ptr(), sorted_expert_ids.data_ptr(),
    num_valid_ids.data_ptr(), w1_scale.data_ptr() if w1_scale is not None else 0, B)
pyhip.kernels.moe.moe_gemm_batch(
    [N2 // 32, sorted_expert_ids.shape[0]], [64], w1.dtype, TOPK, K2, N2, False,
    gemm1_out.data_ptr(), w2.data_ptr(), moe_buf.data_ptr(),
    sorted_ids.data_ptr(), sorted_weights.data_ptr(), sorted_expert_ids.data_ptr(),
    num_valid_ids.data_ptr(), w2_scale.data_ptr() if w2_scale is not None else 0, B)
moe_buf = run_moe_small_batch_hsaco(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
activation,
quant_type,
w1_scale,
w2_scale,
expert_mask,
num_local_tokens,
moe_sorting_dispatch_policy,
)

return moe_buf
return moe_buf

###### ONLY ALLOWED TO CHANGE THIS PART, ENDING HERE ######
if not block_size_M:
block_size_M = -1
return fused_moe_(
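Reviewer note: the new gate boils down to an env-var opt-in plus a batch-size bound. A minimal, self-contained sketch of that predicate follows; the value 16 for the bound is an assumption made here for illustration, the real constant comes from aiter.fused_moe_pyhip_hsaco.

import os

# Assumed value; the real constant is imported in the diff above.
MOE_HSACO_SMALL_BATCH_MAX_B = 16

def use_hsaco_small_batch(batch_size: int) -> bool:
    # Opt-in via AITER_MOE_SMALL_BATCH=1, and only for batch sizes
    # the prebuilt .co kernels actually cover.
    return (
        os.environ.get("AITER_MOE_SMALL_BATCH", "0") == "1"
        and 1 <= batch_size <= MOE_HSACO_SMALL_BATCH_MAX_B
    )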
@@ -333,6 +322,27 @@ def fused_moe_(
# Ensure block_size_M is int (metadata.block_m from CSV may be float)
if block_size_M is not None:
block_size_M = int(block_size_M)
if (
os.environ.get("AITER_MOE_SMALL_BATCH", "0") == "1"
and M == 1
and expert_mask is None
and quant_type == QuantType.per_Token
and activation == ActivationType.Silu
and not doweight_stage1
):
from aiter.fused_moe_pyhip_hsaco import fused_moe_hsaco_batch1_fwd

_hsaco_b1 = fused_moe_hsaco_batch1_fwd(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
w1_scale,
w2_scale,
)
if _hsaco_b1 is not None:
return _hsaco_b1
sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting(
topk_ids,
topk_weight,
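The batch-1 gate above relies on the HSACO entry point returning None when it cannot serve the given shapes or dtypes, so the generic path below still runs in that case. A small sketch of the pattern, with a stand-in kernel (the exact coverage rules of fused_moe_hsaco_batch1_fwd are not shown in this diff):

from typing import Optional
import torch

def _specialized_fwd(x: torch.Tensor) -> Optional[torch.Tensor]:
    # Stand-in for fused_moe_hsaco_batch1_fwd: handle only what the
    # prebuilt kernel covers, signal "not covered" with None.
    if x.shape[0] != 1:
        return None
    return x * 2.0

def forward(x: torch.Tensor) -> torch.Tensor:
    out = _specialized_fwd(x)
    if out is not None:
        return out      # fast path served the request
    return x * 2.0      # generic fallback path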
@@ -660,6 +670,8 @@ class MOEMetadata:
has_bias: bool = False
use_non_temporal_load: bool = True
fuse_fp4_quant: bool = False
# HSACO pyhip ``moe_2stage_splitk`` uses BF16 activations from stage1; skip FP8 requant of ``a2``.
skip_a2_quant: bool = False


def _flydsl_stage1_wrapper(
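For readers skimming the metadata change: the new flag is a plain dataclass field that defaults off. A trimmed sketch, showing only the fields visible in this hunk (the real MOEMetadata has more):

from dataclasses import dataclass

@dataclass
class MOEMetadataSketch:
    has_bias: bool = False
    use_non_temporal_load: bool = True
    fuse_fp4_quant: bool = False
    # When True, stage2 consumes BF16 stage1 activations directly and the
    # FP8 requant of a2 is skipped (HSACO pyhip moe_2stage_splitk path).
    skip_a2_quant: bool = False

meta = MOEMetadataSketch(skip_a2_quant=True)  # opted into the splitk path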
@@ -1196,7 +1208,12 @@ def fused_moe_2stages(
intermediate_pad,
is_shuffled,
)
if (
# pyhip ``moe_gemm_batch`` gate kernel reads BF16 ``p_input`` (see pyhip ``moe.py`` ``buff_a``);
# do not feed FP8-quantized activations from ``get_quant(per_Token)`` into that path.
if metadata.skip_a2_quant:
a1 = hidden_states.to(dtype) if hidden_states.dtype != dtype else hidden_states
a1_scale = None
elif (
quant_type == QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and w1.dtype == dtypes.fp4x2
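The branch above changes what flows into stage1: either FP8-quantized activations plus a per-token scale, or plain BF16 with no scale at all. A hedged sketch of that contract; quantize_per_token is a stand-in written for illustration, the real path uses aiter's get_quant(per_Token) and an FP8 dtype cast.

import torch

def quantize_per_token(x: torch.Tensor):
    # Illustrative per-token scaling only; the real kernel casts to FP8.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    return x / scale, scale

def prepare_a1(hidden_states: torch.Tensor, skip_a2_quant: bool,
               dtype: torch.dtype = torch.bfloat16):
    if skip_a2_quant:
        a1 = hidden_states.to(dtype) if hidden_states.dtype != dtype else hidden_states
        return a1, None            # BF16 activations, no quant scale
    return quantize_per_token(hidden_states)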
@@ -1269,6 +1286,8 @@
dtype=dtype,
device=device,
)
if metadata.skip_a2_quant:
a2.zero_()
extra_stage1_args = {}
extra_stage2_args = {}
if (
@@ -1294,7 +1313,9 @@ w1_scale=(
w1_scale=(
w1_scale.view(dtypes.fp8_e8m0) if w1.dtype == dtypes.fp4x2 else w1_scale
),
sorted_weights=sorted_weights if doweight_stage1 else None,
sorted_weights=sorted_weights
if (doweight_stage1 or metadata.skip_a2_quant)
else None,
**extra_stage1_args,
)
if metadata.fuse_fp4_quant and isinstance(a2, tuple):
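Context for the sorted_weights change: in a 2-stage MoE, each token's output is the weighted sum of its top-k expert outputs, and doweight_stage1 / skip_a2_quant decide which stage applies that multiply. A toy illustration of the math, with arbitrary shapes:

import torch

expert_out = torch.randn(4, 2, 8)   # [tokens, topk, hidden]
topk_weight = torch.rand(4, 2)      # routing weight per (token, expert)

# Weighted combine across the top-k dimension -> [tokens, hidden].
combined = (expert_out * topk_weight.unsqueeze(-1)).sum(dim=1)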
@@ -1347,6 +1368,9 @@
.view(token_num, -1)
)
a2 = a2_v
elif metadata.skip_a2_quant:
a2_scale = None
a2 = a2.view(token_num, topk, inter_dim)
else:
a2, a2_scale = quant_func(
a2,
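The new elif above only reshapes: stage1 leaves a2 flat, and stage2 wants it as [token_num, topk, inter_dim] with no scale attached. A sketch of that step, with the flat input shape assumed for illustration:

import torch

token_num, topk, inter_dim = 4, 2, 8
a2 = torch.randn(token_num * topk, inter_dim, dtype=torch.bfloat16)

a2_scale = None                             # no scale when quant is skipped
a2 = a2.view(token_num, topk, inter_dim)    # zero-copy reshape for stage2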
@@ -1357,6 +1381,10 @@
)
a2 = a2.view(token_num, topk, inter_dim)

# moe_2stage_splitk .co accumulates with global atomics into ``moe_out``; buffer must start at zero.
if metadata.skip_a2_quant:
moe_out.zero_()

metadata.stage2(
a2,
w1,
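Why both zero_() calls (a2 earlier, moe_out here) matter: a split-k kernel has several workgroups atomically adding partial sums into the same buffer, so any garbage left in uninitialized memory would be folded into the result. A CPU-side analogue of the invariant:

import torch

parts = [torch.randn(4, 8) for _ in range(3)]   # partial sums, one per split

out = torch.empty(4, 8)
out.zero_()                # required: accumulation only adds, never stores
for p in parts:
    out.add_(p)            # stand-in for the kernel's global atomic adds

assert torch.allclose(out, sum(parts))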