Commit ae80f53

feat(topk): optimized topk_gating kernel with sigmoid/softmax support (#3100)
* feat(topk): optimized topk_gating kernel with sigmoid/softmax support

Replace topk_softplus with a unified topk_gating kernel that supports three
scoring functions ("sqrtsoftplus" | "sigmoid" | "softmax") via a templated
SCORE_FUNC parameter, plus a Python-facing topk_gating() wrapper.

Performance (vs reference kernels on MI355, bf16):
  sqrtsoftplus E=384 T=1024 K=6 :      81.3 us -> 8.7 us  (9.3x)
  sigmoid      E=128 T=1    K=8 : CK   5.09 us -> 2.12 us (2.41x)
  sigmoid      E=128 T=1024 K=8 : CK   6.06 us -> 3.73 us (1.62x)
  softmax      E=256 T=1024 K=8 : vllm 7.95 us -> 6.32 us (1.26x)
  (vllm topk_softmax still wins for E<=32 / K<=2; we win for E>=64 + K>=4.)

Implementation highlights (register-only opt kernel):
- 1 warp = 1 token, no shared memory, no __syncthreads
- Each thread holds EPT = NUM_EXPERTS / WARP_SIZE elements
- Optimal sorting networks for N=2,3,4,6 (vs unrolled bubble sort)
- Fused DPP warp argmax (multithread_reduce_max_dpp + __ballot + readlane)
  -> ~9 instructions per warp argmax, NaN-safe via ballot fallback
- Cached unbiased scores so the merge phase reads no global memory
- fmaxf(sum, 1e-20f) renorm clamp to handle pathological all-zero rows
- Generic shared-memory fallback for non-power-of-WARP_SIZE expert counts
  (also handles softmax: max-reduce -> exp -> sum-reduce -> divide)

Stream management intentionally uses at::hip::getCurrentHIPStream() for
correctness with non-default-stream callers (CUDA graphs, EP/TP comm streams).
Switching to aiter::getCurrentHIPStream() requires migrating the whole pybind
module to aiter_tensor_t + develop=True; tracked as TODO in the kernel.

Test coverage:
- op_tests/test_moe_topk_gating.py (renamed from test_moe_topk_sigmoid.py)
  covers sqrtsoftplus / sigmoid / softmax across {fp16, bf16, fp32} inputs,
  {fp16, bf16, fp32} biases, num_experts in {64, 128, 256, 384},
  topk in {2, 4, 6, 8}, num_tokens in {1, 64, 256, 1024}.
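
The "optimal sorting networks" highlight can be illustrated with a small sketch. The comparator list below is the standard 5-comparator network for 4 inputs (5 is the known minimum for N=4); it is an assumption that the kernel uses this exact sequence, since the kernel source is not shown in this commit view. The names NETWORK_4 and sort4_descending are hypothetical.

```python
# Hypothetical illustration: a fixed compare-exchange sequence sorts 4 values
# with no data-dependent loop structure, which is what makes sorting networks
# attractive for fully unrolled, register-only GPU code.

# Standard 5-comparator network for 4 inputs (assumed, not taken from the kernel).
NETWORK_4 = [(0, 1), (2, 3), (0, 2), (1, 3), (1, 2)]


def sort4_descending(values):
    """Return the 4 input values sorted from largest to smallest."""
    v = list(values)
    if len(v) != 4:
        raise ValueError("sort4_descending expects exactly 4 values")
    for i, j in NETWORK_4:
        # Compare-exchange: keep the larger value at the lower index.
        if v[i] < v[j]:
            v[i], v[j] = v[j], v[i]
    return v
```

A bubble sort over 4 elements needs 6 compare-exchange steps, so the network saves one step per sort while keeping the instruction sequence fixed.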

Co-authored-by: Cursor <cursoragent@cursor.com>

* add more check

* fix test error

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
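
For readers who want the routing semantics without reading the kernel, here is a minimal pure-Python sketch of one token's gating. It rests on several assumptions that this commit view does not confirm: that "sqrtsoftplus" means sqrt(log(1 + exp(x))), that the correction bias affects only which experts are selected while the returned weights come from the unbiased scores, that ties go to the lower expert index, and that routed_scaling_factor is applied after renormalization. The function names are hypothetical; only the 1e-20 clamp and the "renorm is off for softmax" rule come from the commit.

```python
import math


def _softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def _sigmoid(x):
    # Stable for large |x|: never exponentiates a large positive number.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def score_row(logits, score_func):
    """Map one token's router logits to per-expert scores."""
    if score_func == "sqrtsoftplus":  # assumed definition
        return [math.sqrt(_softplus(x)) for x in logits]
    if score_func == "sigmoid":
        return [_sigmoid(x) for x in logits]
    if score_func == "softmax":
        # max-reduce -> exp -> sum-reduce -> divide, as in the fallback path.
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]
    raise ValueError(f"Unknown score_func {score_func!r}")


def topk_gating_ref(logits, k, score_func="sqrtsoftplus", bias=None,
                    need_renorm=True, routed_scaling_factor=1.0):
    """Reference top-k gating for a single token (hypothetical helper)."""
    scores = score_row(logits, score_func)
    biased = scores if bias is None else [s + b for s, b in zip(scores, bias)]
    # Select on biased scores; ties go to the lower expert index (assumption).
    order = sorted(range(len(scores)), key=lambda i: (-biased[i], i))[:k]
    # Weights come from the unbiased scores (assumption).
    weights = [scores[i] for i in order]
    if need_renorm and score_func != "softmax":
        denom = max(sum(weights), 1e-20)  # clamp guards all-zero rows
        weights = [w / denom for w in weights]
    return [w * routed_scaling_factor for w in weights], order
```

For example, topk_gating_ref([0.0, 2.0, -1.0, 1.0], k=2, score_func="sigmoid") selects experts 1 and 3 and returns two weights that sum to 1 after renormalization.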
1 parent 1c7e7b6 commit ae80f53

7 files changed

Lines changed: 827 additions & 268 deletions

aiter/jit/optCompilerConfig.json

Lines changed: 1 addition & 1 deletion
@@ -534,7 +534,7 @@
         "srcs": [
             "f'{AITER_CSRC_DIR}/pybind/moe_topk_pybind.cu'",
             "f'{AITER_CSRC_DIR}/py_itfs_ck/topk_sigmoid_kernels.cu'",
-            "f'{AITER_CSRC_DIR}/kernels/topk_softplus_kernels.cu'",
+            "f'{AITER_CSRC_DIR}/kernels/topk_gating_kernels.cu'",
             "f'{CK_DIR}/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp'"
         ],
         "flags_extra_cc": [],

aiter/ops/topk.py

Lines changed: 51 additions & 0 deletions
@@ -12,6 +12,12 @@
 from ..utility import dtypes
 
 
+# DEPRECATED: low-level binding kept for backward compatibility only.
+# Will be removed once all callers have migrated to topk_gating() below.
+# New code should use topk_gating(), which:
+# - accepts an Optional[Tensor] correction_bias (None => no bias)
+# - validates score_func string
+# - exposes the same C++ kernel under a more accurate name
 @compile_ops("module_moe_topk")
 def topk_softplus(
     topk_weights: torch.Tensor,
@@ -20,9 +26,54 @@ def topk_softplus(
     correction_bias: torch.Tensor,
     need_renorm: bool,
     routed_scaling_factor: float = 1.0,
+    score_func: str = "sqrtsoftplus",
 ) -> None: ...
 
 
+_VALID_SCORE_FUNCS = {"sqrtsoftplus", "sigmoid", "softmax"}
+
+
+def topk_gating(
+    topk_weights: torch.Tensor,
+    topk_indices: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: Optional[torch.Tensor] = None,
+    need_renorm: bool = True,
+    routed_scaling_factor: float = 1.0,
+    score_func: str = "sqrtsoftplus",
+) -> None:
+    """Unified fused topk gating for MoE routing.
+
+    Args:
+        score_func: one of {"sqrtsoftplus" (DeepSeek V4-Pro default),
+                            "sigmoid" (Llama4),
+                            "softmax" (DeepSeek V3 / classic MoE)}.
+        correction_bias: optional bias tensor, pass None for no bias.
+
+    Note: softmax is already normalized, so renorm is forced off.
+    """
+    assert (
+        score_func in _VALID_SCORE_FUNCS
+    ), f"Unknown score_func '{score_func}', expected one of {_VALID_SCORE_FUNCS}"
+    if correction_bias is None:
+        # Match gating dtype/device so dispatch picks DTYPE_B == DTYPE_I,
+        # avoiding extra kernel template instantiations.
+        correction_bias = torch.empty(
+            0, dtype=gating_output.dtype, device=gating_output.device
+        )
+    if score_func == "softmax":
+        need_renorm = False
+    topk_softplus(
+        topk_weights,
+        topk_indices,
+        gating_output,
+        correction_bias,
+        need_renorm,
+        routed_scaling_factor,
+        score_func,
+    )
+
+
 @compile_ops("module_moe_asm", fc_name="biased_grouped_topk")
 def biased_grouped_topk_hip(
     gating_output: torch.Tensor,
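
The argument handling that topk_gating() performs before calling the binding can be sketched without torch or a GPU. This is a hypothetical stand-in, not the wrapper itself: resolve_gating_args is an invented name, an empty list stands in for the zero-element bias tensor, and it raises ValueError where the wrapper uses assert (a variation worth considering, since assert statements are skipped under python -O).

```python
from typing import List, Optional, Sequence, Tuple

VALID_SCORE_FUNCS = ("sqrtsoftplus", "sigmoid", "softmax")


def resolve_gating_args(
    score_func: str,
    correction_bias: Optional[Sequence[float]],
    need_renorm: bool = True,
) -> Tuple[List[float], bool]:
    """Mirror the wrapper's pre-call normalization (hypothetical helper)."""
    if score_func not in VALID_SCORE_FUNCS:
        raise ValueError(
            f"Unknown score_func {score_func!r}, expected one of {VALID_SCORE_FUNCS}"
        )
    # None becomes an empty bias, standing in for the zero-element tensor.
    bias = [] if correction_bias is None else list(correction_bias)
    # Renormalization is forced off for softmax, as in the wrapper.
    if score_func == "softmax":
        need_renorm = False
    return bias, need_renorm
```

With this sketch, resolve_gating_args("softmax", None) yields an empty bias and need_renorm=False, while an unknown score function fails before any kernel would be launched.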

csrc/include/moe_op.h

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ void topk_softplus(torch::Tensor& topk_weights,
                    torch::Tensor& gating_output,
                    torch::Tensor& correction_bias,
                    bool need_renorm,
-                   float routed_scaling_factor = 1.0);
+                   float routed_scaling_factor = 1.0,
+                   const std::string& score_func = "sqrtsoftplus");
 
 void moe_align_block_size(torch::Tensor topk_ids,
                           int64_t num_experts,

csrc/include/rocm_ops.hpp

Lines changed: 2 additions & 1 deletion
@@ -1180,7 +1180,8 @@ namespace py = pybind11;
           py::arg("correction_bias"), \
           py::arg("need_renorm"), \
           py::arg("routed_scaling_factor") = 1.0, \
-          "Apply topk sqrtsoftplus to the gating outputs.");
+          py::arg("score_func") = "sqrtsoftplus", \
+          "Fused topk gating: score_func='sqrtsoftplus'|'sigmoid'|'softmax'.");
 
 #define MOE_SORTING_PYBIND \
     m.def("moe_sorting_fwd", \
