feat(topk): optimized topk_gating kernel with sigmoid/softmax support#3100
Merged
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Contributor
There was a problem hiding this comment.
Pull request overview
This PR replaces the dedicated topk_softplus kernel implementation with a unified, optimized fused “topk gating” kernel that supports multiple MoE routing score functions (sqrtsoftplus, sigmoid, softmax) and exposes it via a Python-facing wrapper while keeping the existing topk_softplus entrypoint as the underlying binding.
Changes:
- Added a new HIP/C++ fused gating kernel (
topk_gating_kernels.cu) with compile-time score-function selection and opt/fallback execution paths. - Extended the C++/pybind and Python APIs to accept a
score_funcselector and added atopk_gating()convenience wrapper. - Expanded op_test coverage/benchmarks to exercise
sqrtsoftplus,sigmoid, andsoftmaxrouting paths.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/test_moe_topk_gating.py | Updates benchmarks/tests to use topk_gating, adds softmax coverage, and expands parameter grids. |
| csrc/kernels/topk_softplus_kernels.cu | Removes the previous standalone softplus kernel implementation. |
| csrc/kernels/topk_gating_kernels.cu | Adds the new unified fused topk gating kernel with score-function dispatch and optimized path. |
| csrc/include/rocm_ops.hpp | Updates pybind signature/docs to include the new score_func argument. |
| csrc/include/moe_op.h | Extends topk_softplus declaration to accept score_func. |
| aiter/ops/topk.py | Adds Python topk_gating() wrapper and validates score_func. |
| aiter/jit/optCompilerConfig.json | Switches JIT compilation sources from removed softplus kernel to the new gating kernel. |
Comments suppressed due to low confidence (4)
op_tests/test_moe_topk_gating.py:40
- The perftest wrapper always runs
num_itersiterations undertorch.profilereven when called from correctness checks and the CI harness runs this file directly. Bumpingnum_itersto 100 here is likely to make this op_test much slower and risk hitting the 60-minute per-file CI timeout; consider keeping CI defaults small (e.g., 10) and making higher-iteration benchmarking opt-in via a CLI flag/env var.
op_tests/test_moe_topk_gating.py:77 - Same concern as above:
@perftest(num_iters=100, ...)runs 100 profiled iterations per call. Since this script is executed directly in CI, this setting can significantly increase runtime across the many parameter combinations; consider reducingnum_itersfor the default path and gating longer perf runs behind an explicit benchmark mode.
op_tests/test_moe_topk_gating.py:113 @perftest(num_iters=100, ...)in this script will run 100 profiled iterations for every softmax config. With the expanded default config grid, this can substantially slow CI runs; consider using a smaller default iteration count for op_tests and making longer perf runs opt-in.
op_tests/test_moe_topk_gating.py:565- In the
__main__path, configs withid_errors > 0.01only print a warning and still exit 0. Since CI runs op_tests scripts directly, correctness regressions in the new softmax path may not fail CI; consider raising an exception / exiting non-zero when errors exceed the threshold.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Replace topk_softplus with a unified topk_gating kernel that supports
three scoring functions ("sqrtsoftplus" | "sigmoid" | "softmax") via a
templated SCORE_FUNC parameter, plus a Python-facing topk_gating()
wrapper.
Performance (vs reference kernels on MI355, bf16):
sqrtsoftplus E=384 T=1024 K=6: 81.3 us -> 8.7 us (9.3x)
sigmoid E=128 T=1 K=8 : CK 5.09 us -> 2.12 us (2.41x)
sigmoid E=128 T=1024 K=8 : CK 6.06 us -> 3.73 us (1.62x)
softmax E=256 T=1024 K=8 : vllm 7.95 us -> 6.32 us (1.26x)
(vllm topk_softmax still wins for E<=32 / K<=2;
we win for E>=64 + K>=4.)
Implementation highlights (register-only opt kernel):
- 1 warp = 1 token, no shared memory, no __syncthreads
- Each thread holds EPT = NUM_EXPERTS / WARP_SIZE elements
- Optimal sorting networks for N=2,3,4,6 (vs unrolled bubble sort)
- Fused DPP warp argmax (multithread_reduce_max_dpp + __ballot + readlane)
-> ~9 instructions per warp argmax, NaN-safe via ballot fallback
- Cached unbiased scores so the merge phase reads no global memory
- fmaxf(sum, 1e-20f) renorm clamp to handle pathological all-zero rows
- Generic shared-memory fallback for non-power-of-WARP_SIZE expert counts
(also handles softmax: max-reduce -> exp -> sum-reduce -> divide)
Stream management intentionally uses at::hip::getCurrentHIPStream()
for correctness with non-default-stream callers (CUDA graphs, EP/TP
comm streams). Switching to aiter::getCurrentHIPStream() requires
migrating the whole pybind module to aiter_tensor_t + develop=True;
tracked as TODO in the kernel.
Test coverage:
- op_tests/test_moe_topk_gating.py (renamed from test_moe_topk_sigmoid.py)
covers sqrtsoftplus / sigmoid / softmax across {fp16, bf16, fp32}
inputs, {fp16, bf16, fp32} biases, num_experts in {64, 128, 256, 384},
topk in {2, 4, 6, 8}, num_tokens in {1, 64, 256, 1024}.
Co-authored-by: Cursor <cursoragent@cursor.com>
b4fce19 to
a0fde42
Compare
Contributor
Author
Contributor
Author
valarLip
approved these changes
May 12, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.


Motivation
Replace topk_softplus with a unified topk_gating kernel that supports three scoring functions ("sqrtsoftplus" | "sigmoid" | "softmax") via a templated SCORE_FUNC parameter, plus a Python-facing topk_gating() wrapper.
Performance (vs reference kernels on MI355, bf16):
sqrtsoftplus E=384 T=1024 K=6: 81.3 us -> 8.7 us (9.3x)
sigmoid E=128 T=1 K=8 : CK 5.09 us -> 2.12 us (2.41x)
sigmoid E=128 T=1024 K=8 : CK 6.06 us -> 3.73 us (1.62x)
softmax E=256 T=1024 K=8 : vllm 7.95 us -> 6.32 us (1.26x)
(vllm topk_softmax still wins for E<=32 / K<=2;
we win for E>=64 + K>=4.)
Technical Details
Implementation highlights (register-only opt kernel):
Stream management intentionally uses at::hip::getCurrentHIPStream() for correctness with non-default-stream callers (CUDA graphs, EP/TP comm streams). Switching to aiter::getCurrentHIPStream() requires migrating the whole pybind module to aiter_tensor_t + develop=True; tracked as TODO in the kernel.
Test Plan
Test coverage:
Test Result
Submission Checklist