[fused_router][pytorch] Optimize naive topk path and add perf benchmark#2776
Conversation
|
Test on H200 ,CUDA version 13.0 |
Greptile SummaryThis PR optimizes the fused-router CUDA kernels by specializing
Confidence Score: 5/5Safe to merge; changes are a focused set of GPU micro-optimizations with correct warp semantics and an important forward-dispatch bug fix. The expert_bias.has_data() fix removes a real crash path that previously dereferenced an invalid dtype field when no bias was supplied. The warp reduce and smallk specialization are mechanically correct: CompType is float throughout, the broadcast after __shfl_down_sync is present and load-bearing, and the register-based masking in smallk is consistent across all lanes via the broadcast before the selected[] write. No data-correctness or memory-safety issues were found. transformer_engine/common/fused_router/utils.h — the new naive_topk_and_mask_smallk path warrants a close read, specifically the interaction between the per-lane register array (selected[]) and the lane-0-only shared-memory writes (topk_indices/topk_scores) protected by __syncwarp(). Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["naive_topk_and_mask(scores, data_size, topk, ...)"] --> B{topk <= 8?}
B -- "case 1..8" --> C["naive_topk_and_mask_smallk<K>(...)"]
B -- "default" --> D["naive_topk_and_mask_generic(...)"]
C --> E["Register array: selected[K] = {-1...}"]
E --> F["unroll k=0..K-1"]
F --> G["Inner loop: per-lane max, skip if selected[j]==i"]
G --> H["XOR butterfly reduce (all lanes get max)"]
H --> I["shfl_sync broadcast → canonical chosen_index"]
I --> J["lane==0: write topk_indices/scores shmem\nAll lanes: selected[k]=chosen_index\n__syncwarp()"]
J --> F
D --> K["is_masked lambda via topk_indices shmem"]
K --> L["Loop k=0..topk-1, XOR reduce, lane 0 writes, __syncwarp()"]
subgraph "Forward dispatch"
M["expert_bias.has_data()?"] -- yes --> N["BiasType from bias.dtype"]
M -- no --> O["BiasType=DataType, expert_bias=nullptr"]
end
subgraph "warp_reduce_sum_float"
P["shfl_down_sync → accumulate to lane 0"]
P --> Q["shfl_sync broadcast → all lanes"]
end
Reviews (9): Last reviewed commit: "fused_router: address review feedback" | Re-trigger Greptile |
| __device__ inline float warp_reduce_sum_float(float val) { | ||
| for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) { | ||
| val += __shfl_down_sync(0xffffffff, val, offset); | ||
| } | ||
| return __shfl_sync(0xffffffff, val, 0); | ||
| } |
There was a problem hiding this comment.
Broadcast from lane 0 should be documented
__shfl_down_sync accumulates the sum only in lane 0 after all steps — unlike the __shfl_xor_sync butterfly approach (used elsewhere in this file) which gives the correct sum to every lane simultaneously. The subsequent __shfl_sync(…, 0) is therefore load-bearing for correctness, not just an optimisation.
Adding a short comment here prevents a future reader from accidentally removing it thinking it's redundant:
| __device__ inline float warp_reduce_sum_float(float val) { | |
| for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) { | |
| val += __shfl_down_sync(0xffffffff, val, offset); | |
| } | |
| return __shfl_sync(0xffffffff, val, 0); | |
| } | |
| __device__ inline float warp_reduce_sum_float(float val) { | |
| // __shfl_down_sync accumulates the total only in lane 0; | |
| // the broadcast below is required for all lanes to see the result. | |
| for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) { | |
| val += __shfl_down_sync(0xffffffff, val, offset); | |
| } | |
| return __shfl_sync(0xffffffff, val, 0); | |
| } |
| #pragma unroll | ||
| for (int k = 0; k < K; ++k) { | ||
| CompType val = -std::numeric_limits<CompType>::infinity(); | ||
| int index = (lane_id < data_size) ? lane_id : 0; | ||
| for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { | ||
| bool masked = false; | ||
| #pragma unroll | ||
| for (int j = 0; j < k; ++j) { | ||
| masked |= (selected[j] == i); | ||
| } | ||
| if (masked) continue; | ||
| CompType cur_val = scores[i]; | ||
| if (cur_val > val) { | ||
| val = cur_val; | ||
| index = i; | ||
| } | ||
| } | ||
| for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) { | ||
| auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); | ||
| auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); | ||
| if (shuffled_val > val) { | ||
| val = shuffled_val; | ||
| index = shuffled_index; | ||
| } | ||
| } |
There was a problem hiding this comment.
OOB thread
index initialised to 0 could shadow earlier selections
For threads where lane_id >= data_size, the inner loop body never executes, so val stays at -inf and index is set to 0:
int index = (lane_id < data_size) ? lane_id : 0;0 is a valid data index that may already have been placed in selected by a previous k iteration. During the XOR-reduction phase, these OOB threads participate (they shuffle -inf values, which can never win the shuffled_val > val comparison), so the final chosen_index remains correct. However, after the broadcast:
selected[k] = chosen_index;every thread — including OOB ones — writes chosen_index to their register copy of selected, keeping all threads in sync. The algorithm is therefore correct, but initialising the fallback index to a sentinel value (e.g., -1 or data_size - 1) would make the intent clearer and avoid confusion with a real element:
| #pragma unroll | |
| for (int k = 0; k < K; ++k) { | |
| CompType val = -std::numeric_limits<CompType>::infinity(); | |
| int index = (lane_id < data_size) ? lane_id : 0; | |
| for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { | |
| bool masked = false; | |
| #pragma unroll | |
| for (int j = 0; j < k; ++j) { | |
| masked |= (selected[j] == i); | |
| } | |
| if (masked) continue; | |
| CompType cur_val = scores[i]; | |
| if (cur_val > val) { | |
| val = cur_val; | |
| index = i; | |
| } | |
| } | |
| for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) { | |
| auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); | |
| auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); | |
| if (shuffled_val > val) { | |
| val = shuffled_val; | |
| index = shuffled_index; | |
| } | |
| } | |
| int index = (lane_id < data_size) ? lane_id : -1; // -1: sentinel for out-of-range lane |
This is purely a readability / defensive-programming concern given that invalid index values can never propagate to the output.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
|
||
| seed = 42 | ||
| torch.manual_seed(seed) | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.manual_seed(seed) |
There was a problem hiding this comment.
Module-level seed is fragile for
tokens_per_expert randomness
The random seed is set once at import time. If pytest collects or runs other tests before this module's tests execute, the global random state will have advanced and test_fused_moe_aux_loss_perf_against_torch (which calls torch.randint for tokens_per_expert) will use an unknown seed. While the numerical correctness check in that test (torch.testing.assert_close(torch_loss, fused_loss)) passes regardless of the specific random values, reproducible benchmarks are easier to debug.
Consider moving the seed setup into each individual test function, or using a pytest fixture to ensure a consistent state per test.
There was a problem hiding this comment.
@XiaomingFun233 I went through the Greptile comments and I agree with the suggested changes except for the speedup assertion — we should keep the benchmark test in the suite for manual verification but our CI pipeline is for functional testing, not benchmarking, so the assertion does not make sense. Could you address the remaining issues and rebase the branch on latest TE/main? We can launch the CI on our end for testing.
Please also check out the contributing guidelines, particularly regarding the sign-off for your commits and the license information that needs to be added to the source files.
Thanks!
ok I will complete this work |
- restore forward hot paths to baseline behavior for topk/scores kernels\n- keep warp-level reduction helper for backward normalization\n- handle empty expert_bias safely in fused topk forward Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
Add a lightweight register-based small-k path and keep the generic fallback for compatibility. Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels. Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
066b9ea to
3ad3ad2
Compare
for more information, see https://pre-commit.ci
|
Addressed the remaining issues from the review.
Please let me know if there is anything else you would like me to adjust |
|
please checkout this new change @denera |
|
@ptrendx please check out |
Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
|
@denera Addressed the remaining review feedback.
The benchmark test remains opt-in for manual verification, perf-style assertions are removed, the branch has been rebased onto the latest TE/main, |
|
/te-ci pytorch |
…rk (NVIDIA#2776) * fused_router: keep low-risk CUDA optimizations - restore forward hot paths to baseline behavior for topk/scores kernels\n- keep warp-level reduction helper for backward normalization\n- handle empty expert_bias safely in fused topk forward Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> * fused_router: specialize naive_topk_and_mask for topk<=8 Add a lightweight register-based small-k path and keep the generic fallback for compatibility. Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> * tests: add fused router performance benchmark Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels. Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fused_router: address review feedback Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> --------- Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
…rk (NVIDIA#2776) * fused_router: keep low-risk CUDA optimizations - restore forward hot paths to baseline behavior for topk/scores kernels\n- keep warp-level reduction helper for backward normalization\n- handle empty expert_bias safely in fused topk forward Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> * fused_router: specialize naive_topk_and_mask for topk<=8 Add a lightweight register-based small-k path and keep the generic fallback for compatibility. Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> * tests: add fused router performance benchmark Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels. Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fused_router: address review feedback Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> --------- Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Summary
This PR ports and keeps a focused set of CUDA fused-router optimizations that showed consistent gains on the tested workload, while avoiding heavier variants that regressed performance.
1. Add fused-router performance benchmark test
tests/pytorch/test_fused_router_perf.py.fused_topk_with_score_functionfused_compute_score_for_moe_aux_lossfused_moe_aux_loss2. Keep low-risk fused-router CUDA optimizations
transformer_engine/common/fused_router/utils.htransformer_engine/common/fused_router/fused_topk_with_score_function.cuexpert_bias.has_data()handling in forward to avoid invalid dtype switch when bias is absent.transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu3. Optimize
naive_topk_and_maskfor small-ktransformer_engine/common/fused_router/utils.htopk <= 8.Performance (A/B)
Measured with:
TE_RUN_PERF_TESTS=1 pytest -q tests/pytorch/test_fused_router_perf.py -sBefore
topk_router[softmax]: fused0.029562 ms, speedup8.3067xtopk_router[sigmoid]: fused0.030138 ms, speedup7.2715xscores_for_aux_loss[softmax]: fused0.026183 ms, speedup3.8721xscores_for_aux_loss[sigmoid]: fused0.025872 ms, speedup3.8892xmoe_aux_loss: fused0.015680 ms, speedup1.8884xAfter
topk_router[softmax]: fused0.022384 ms, speedup11.1324xtopk_router[sigmoid]: fused0.022840 ms, speedup9.7714xscores_for_aux_loss[softmax]: fused0.017230 ms, speedup5.9707xscores_for_aux_loss[sigmoid]: fused0.017049 ms, speedup6.0205xmoe_aux_loss: fused0.015412 ms, speedup1.8424xNotes
topk_router/scores_for_aux_lossperformance on this setup.