
feat(topk): optimized topk_gating kernel with sigmoid/softmax support #3100

Merged

valarLip merged 6 commits into ROCm:main from yzhou103:opt_topk_softplus_bias on May 12, 2026
Conversation

@yzhou103 (Contributor) commented May 9, 2026

Motivation

Replace topk_softplus with a unified topk_gating kernel that supports three scoring functions ("sqrtsoftplus" | "sigmoid" | "softmax") via a templated SCORE_FUNC parameter, and add a Python-facing topk_gating() wrapper.
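Because the score function is a template parameter, each instantiation pays no per-element branch. A minimal sketch of that dispatch shape (the enum, helper, and signature below are hypothetical illustrations, not the identifiers used in csrc/kernels/topk_gating_kernels.cu):

```cuda
// Hypothetical sketch of compile-time score-function selection; names and
// signatures are illustrative, not the PR's actual ones.
enum class ScoreFunc { SqrtSoftplus, Sigmoid, Softmax };

template <ScoreFunc SCORE_FUNC>
__device__ float apply_score(float x)
{
    if constexpr (SCORE_FUNC == ScoreFunc::SqrtSoftplus)
        return sqrtf(log1pf(expf(x)));   // sqrt(softplus(x)), assuming that reading
    else if constexpr (SCORE_FUNC == ScoreFunc::Sigmoid)
        return 1.0f / (1.0f + expf(-x));
    else
        return x;                        // softmax needs a row-wide reduction,
                                         // so it is applied outside this helper
}

template <ScoreFunc SCORE_FUNC, int NUM_EXPERTS, int TOPK>
__global__ void topk_gating_kernel(const float* logits,
                                   const float* bias,
                                   float*       topk_weights,
                                   int*         topk_ids,
                                   int          num_tokens);
```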

Performance (vs reference kernels on MI355, bf16):

  sqrtsoftplus E=384 T=1024 K=6 :      81.3 us -> 8.7 us  (9.3x)
  sigmoid      E=128 T=1    K=8 : CK   5.09 us -> 2.12 us (2.41x)
  sigmoid      E=128 T=1024 K=8 : CK   6.06 us -> 3.73 us (1.62x)
  softmax      E=256 T=1024 K=8 : vllm 7.95 us -> 6.32 us (1.26x)

  (vllm topk_softmax still wins for E<=32 / K<=2; this kernel wins for E>=64 and K>=4.)

Technical Details

Implementation highlights (register-only opt kernel):

  • 1 warp = 1 token: no shared memory, no __syncthreads
  • Each thread holds EPT = NUM_EXPERTS / WARP_SIZE elements
  • Optimal sorting networks for N = 2, 3, 4, 6 instead of an unrolled bubble sort (see the first sketch after this list)
  • Fused DPP warp argmax (multithread_reduce_max_dpp + __ballot + readlane): ~9 instructions per warp argmax, NaN-safe via a ballot fallback (also sketched below)
  • Unbiased scores are cached so the merge phase reads no global memory
  • fmaxf(sum, 1e-20f) renorm clamp to handle pathological all-zero rows
  • Generic shared-memory fallback for expert counts that are not a multiple of WARP_SIZE (it also handles softmax: max-reduce -> exp -> sum-reduce -> divide)
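To make the sorting-network item concrete, here is the classic optimal 5-comparator network for N = 4 in descending order (function names are mine, not the kernel's); an unrolled bubble sort would need 6 compare-exchanges:

```cuda
// Optimal 5-comparator sorting network for N = 4 (descending order),
// using the classic pair sequence (0,1)(2,3)(0,2)(1,3)(1,2).
__device__ __forceinline__ void cswap(float& a, float& b, int& ia, int& ib)
{
    if (a < b) {  // keep the larger value (and its expert id) first
        float t = a; a = b; b = t;
        int   u = ia; ia = ib; ib = u;
    }
}

__device__ void sort4_desc(float v[4], int id[4])
{
    cswap(v[0], v[1], id[0], id[1]);
    cswap(v[2], v[3], id[2], id[3]);
    cswap(v[0], v[2], id[0], id[2]);
    cswap(v[1], v[3], id[1], id[3]);
    cswap(v[1], v[2], id[1], id[2]);
}
```

And a hedged stand-in for the warp argmax: the kernel's multithread_reduce_max_dpp sequence is AMD-specific and runs on 64-lane warps, so the 32-lane CUDA shuffle version below only mirrors its structure (max-reduce, ballot, broadcast), including the NaN fallback:

```cuda
// Portable sketch of the warp argmax described above; NOT the actual
// DPP implementation, just the same shape in CUDA intrinsics.
__device__ void warp_argmax(float val, int idx, float& out_val, int& out_idx)
{
    float m = val;
    for (int offset = 16; offset > 0; offset >>= 1)   // butterfly max-reduce:
        m = fmaxf(m, __shfl_xor_sync(0xffffffffu, m, offset)); // all lanes get the max

    // Ballot over "do I hold the maximum?" picks the winning lane.
    unsigned mask = __ballot_sync(0xffffffffu, val == m);

    // NaN safety: if every element is NaN, val == m is false on all lanes
    // and mask is 0; fall back to lane 0 instead of reading lane -1.
    int src = (mask != 0u) ? (__ffs(mask) - 1) : 0;

    out_val = __shfl_sync(0xffffffffu, val, src);     // broadcast the winner
    out_idx = __shfl_sync(0xffffffffu, idx, src);     // (readlane equivalent on AMD)
}
```

The renorm clamp from the list is then a one-liner on the accumulated top-k sum, e.g. weight[i] /= fmaxf(sum, 1e-20f), so a pathological all-zero row divides by 1e-20 rather than by 0.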

Stream management intentionally uses at::hip::getCurrentHIPStream() so the kernel remains correct for non-default-stream callers (CUDA graphs, EP/TP communication streams). Switching to aiter::getCurrentHIPStream() would require migrating the whole pybind module to aiter_tensor_t + develop=True; this is tracked as a TODO in the kernel.
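A sketch of the launch pattern this implies (the include path, grid/block sizing, and call-site shape are assumptions, not copied from the PR):

```cuda
#include <ATen/hip/HIPContext.h>  // assumed header for at::hip::getCurrentHIPStream()

void launch_topk_gating(const float* logits, const float* bias,
                        float* topk_weights, int* topk_ids, int num_tokens)
{
    // Launch on PyTorch's current HIP stream so the kernel serializes
    // correctly behind CUDA-graph capture and EP/TP communication streams,
    // instead of hard-coding the default stream.
    hipStream_t stream = at::hip::getCurrentHIPStream();

    const int warps_per_block = 4;             // assumed; 1 warp handles 1 token
    dim3 block(64 * warps_per_block);          // 64-lane warps on CDNA GPUs
    dim3 grid((num_tokens + warps_per_block - 1) / warps_per_block);

    // Hypothetical instantiation; real dispatch lives in the pybind layer:
    // topk_gating_kernel<ScoreFunc::Sigmoid, /*NUM_EXPERTS=*/128, /*TOPK=*/8>
    //     <<<grid, block, 0, stream>>>(logits, bias, topk_weights, topk_ids, num_tokens);
}
```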

Test Plan

Test coverage:

  • op_tests/test_moe_topk_gating.py (renamed from test_moe_topk_sigmoid.py) covers sqrtsoftplus / sigmoid / softmax across {fp16, bf16, fp32} inputs, {fp16, bf16, fp32} biases, num_experts in {64, 128, 256, 384}, topk in {2, 4, 6, 8}, num_tokens in {1, 64, 256, 1024}.


yzhou103 requested review from a team and Copilot on May 9, 2026 06:22
@github-actions (Bot) commented May 9, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label           Tests
ci:triton-300x  Run an additional Triton test job on MI300X in PRs; the main branch always runs both MI35X and MI300X
ci:sglang       SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom         ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:vllm         vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all          All of the above

Add labels via the sidebar or gh pr edit 3100 --add-label <label>

Copilot AI (Contributor) left a comment
Pull request overview

This PR replaces the dedicated topk_softplus kernel implementation with a unified, optimized fused “topk gating” kernel that supports multiple MoE routing score functions (sqrtsoftplus, sigmoid, softmax) and exposes it via a Python-facing wrapper while keeping the existing topk_softplus entrypoint as the underlying binding.

Changes:

  • Added a new HIP/C++ fused gating kernel (topk_gating_kernels.cu) with compile-time score-function selection and opt/fallback execution paths.
  • Extended the C++/pybind and Python APIs to accept a score_func selector and added a topk_gating() convenience wrapper.
  • Expanded op_test coverage/benchmarks to exercise sqrtsoftplus, sigmoid, and softmax routing paths.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

File                                   Description
op_tests/test_moe_topk_gating.py       Updates benchmarks/tests to use topk_gating, adds softmax coverage, and expands parameter grids.
csrc/kernels/topk_softplus_kernels.cu  Removes the previous standalone softplus kernel implementation.
csrc/kernels/topk_gating_kernels.cu    Adds the new unified fused topk gating kernel with score-function dispatch and an optimized path.
csrc/include/rocm_ops.hpp              Updates the pybind signature/docs to include the new score_func argument.
csrc/include/moe_op.h                  Extends the topk_softplus declaration to accept score_func.
aiter/ops/topk.py                      Adds the Python topk_gating() wrapper and validates score_func.
aiter/jit/optCompilerConfig.json       Switches JIT compilation sources from the removed softplus kernel to the new gating kernel.
Comments suppressed due to low confidence (4)

op_tests/test_moe_topk_gating.py:40
  • The perftest wrapper always runs num_iters iterations under torch.profiler, even when called from correctness checks, and the CI harness runs this file directly. Bumping num_iters to 100 here is likely to make this op_test much slower and risk hitting the 60-minute per-file CI timeout; consider keeping CI defaults small (e.g., 10) and making higher-iteration benchmarking opt-in via a CLI flag/env var.

op_tests/test_moe_topk_gating.py:77
  • Same concern as above: @perftest(num_iters=100, ...) runs 100 profiled iterations per call. Since this script is executed directly in CI, this setting can significantly increase runtime across the many parameter combinations; consider reducing num_iters for the default path and gating longer perf runs behind an explicit benchmark mode.

op_tests/test_moe_topk_gating.py:113
  • @perftest(num_iters=100, ...) in this script will run 100 profiled iterations for every softmax config. With the expanded default config grid, this can substantially slow CI runs; consider using a smaller default iteration count for op_tests and making longer perf runs opt-in.

op_tests/test_moe_topk_gating.py:565
  • In the __main__ path, configs with id_errors > 0.01 only print a warning and still exit 0. Since CI runs op_tests scripts directly, correctness regressions in the new softmax path may not fail CI; consider raising an exception / exiting non-zero when errors exceed the threshold.


@yzhou103 force-pushed the opt_topk_softplus_bias branch from b4fce19 to a0fde42 on May 9, 2026 06:49
@yzhou103 (Contributor, Author) commented May 9, 2026: [image]

@yzhou103 (Contributor, Author) commented: [image]
@valarLip merged commit ae80f53 into ROCm:main on May 12, 2026. 30 checks passed.