Skip to content

[fused_router][pytorch] Optimize naive topk path and add perf benchmark#2776

Merged
denera merged 11 commits into
NVIDIA:mainfrom
XiaomingFun233:pr/fused-router-topk-opt
May 28, 2026
Merged

[fused_router][pytorch] Optimize naive topk path and add perf benchmark#2776
denera merged 11 commits into
NVIDIA:mainfrom
XiaomingFun233:pr/fused-router-topk-opt

Conversation

@XiaomingFun233

Copy link
Copy Markdown
Contributor

Summary

This PR ports and keeps a focused set of CUDA fused-router optimizations that showed consistent gains on the tested workload, while avoiding heavier variants that regressed performance.

1. Add fused-router performance benchmark test

  • Add tests/pytorch/test_fused_router_perf.py.
  • Benchmark coverage:
    • fused_topk_with_score_function
    • fused_compute_score_for_moe_aux_loss
    • fused_moe_aux_loss

2. Keep low-risk fused-router CUDA optimizations

  • transformer_engine/common/fused_router/utils.h
    • Add warp-level sum helper used in backward normalization path.
  • transformer_engine/common/fused_router/fused_topk_with_score_function.cu
    • Use warp-level sum reduction in backward normalization.
    • Add safe expert_bias.has_data() handling in forward to avoid invalid dtype switch when bias is absent.
  • transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu
    • Use warp-level sum reduction in backward normalization.

3. Optimize naive_topk_and_mask for small-k

  • transformer_engine/common/fused_router/utils.h
    • Add lightweight specialization for topk <= 8.
    • Keep generic fallback for compatibility.

Performance (A/B)

Measured with:

  • TE_RUN_PERF_TESTS=1 pytest -q tests/pytorch/test_fused_router_perf.py -s

Before

  • topk_router[softmax]: fused 0.029562 ms, speedup 8.3067x
  • topk_router[sigmoid]: fused 0.030138 ms, speedup 7.2715x
  • scores_for_aux_loss[softmax]: fused 0.026183 ms, speedup 3.8721x
  • scores_for_aux_loss[sigmoid]: fused 0.025872 ms, speedup 3.8892x
  • moe_aux_loss: fused 0.015680 ms, speedup 1.8884x

After

  • topk_router[softmax]: fused 0.022384 ms, speedup 11.1324x
  • topk_router[sigmoid]: fused 0.022840 ms, speedup 9.7714x
  • scores_for_aux_loss[softmax]: fused 0.017230 ms, speedup 5.9707x
  • scores_for_aux_loss[sigmoid]: fused 0.017049 ms, speedup 6.0205x
  • moe_aux_loss: fused 0.015412 ms, speedup 1.8424x

Notes

  • This PR intentionally avoids the larger full-port variant that previously regressed topk_router/scores_for_aux_loss performance on this setup.

@XiaomingFun233

Copy link
Copy Markdown
Contributor Author

Test on H200 ,CUDA version 13.0

@greptile-apps

greptile-apps Bot commented Mar 18, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR optimizes the fused-router CUDA kernels by specializing naive_topk_and_mask for small-k (≤8) with a template that unrolls masking checks into registers, extracts a warp_reduce_sum_float helper replacing the inline __shfl_xor_sync loops, and fixes a crash path in the forward dispatch when expert_bias is absent. A new benchmark-only test file covers all three router operations.

  • naive_topk_and_mask_smallk<K>: per-thread selected[K] register array avoids repeated shared-memory reads for masking; XOR butterfly gives all lanes the maximum, then a __shfl_sync broadcast establishes a canonical chosen_index; __syncwarp() after each k-step makes the lane-0 shmem writes visible before the next read.
  • expert_bias.has_data() guard: previously expert_bias.data.dtype was always dereferenced in the type-switch macro, causing undefined behavior when no bias tensor was provided; the else branch now passes nullptr, which the existing in-kernel null-check handles.
  • warp_reduce_sum_float: uses __shfl_down_sync (accumulates to lane 0) plus an explicit __shfl_sync broadcast; functionally identical to the removed inline __shfl_xor_sync butterfly but extracted for reuse.

Confidence Score: 5/5

Safe to merge; changes are a focused set of GPU micro-optimizations with correct warp semantics and an important forward-dispatch bug fix.

The expert_bias.has_data() fix removes a real crash path that previously dereferenced an invalid dtype field when no bias was supplied. The warp reduce and smallk specialization are mechanically correct: CompType is float throughout, the broadcast after __shfl_down_sync is present and load-bearing, and the register-based masking in smallk is consistent across all lanes via the broadcast before the selected[] write. No data-correctness or memory-safety issues were found.

transformer_engine/common/fused_router/utils.h — the new naive_topk_and_mask_smallk path warrants a close read, specifically the interaction between the per-lane register array (selected[]) and the lane-0-only shared-memory writes (topk_indices/topk_scores) protected by __syncwarp().

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/utils.h Adds warp_reduce_sum_float helper (shfl_down_sync + broadcast) and naive_topk_and_mask_smallk template for topk≤8; OOB lane index correctly set to -1; fallback index=lane_id for in-bounds threads whose assigned elements are all masked is benign for valid inputs where topk≤data_size
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Bug fix: adds expert_bias.has_data() guard before accessing expert_bias.data.dtype; kernel already null-checks expert_bias so passing nullptr is safe; backward warp-reduce refactored to warp_reduce_sum_float helper
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Backward normalization warp reduction refactored from inline shfl_xor_sync loop to warp_reduce_sum_float helper; functionally equivalent, CompType=float matches the float parameter
tests/pytorch/test_fused_router_perf.py New benchmark-only test guarded by TE_RUN_PERF_TESTS; correctness verified on pre-benchmark results via assert_close; deterministic input generation; _set_seed() called per-function

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["naive_topk_and_mask(scores, data_size, topk, ...)"] --> B{topk <= 8?}
    B -- "case 1..8" --> C["naive_topk_and_mask_smallk<K>(...)"]
    B -- "default" --> D["naive_topk_and_mask_generic(...)"]
    C --> E["Register array: selected[K] = {-1...}"]
    E --> F["unroll k=0..K-1"]
    F --> G["Inner loop: per-lane max, skip if selected[j]==i"]
    G --> H["XOR butterfly reduce (all lanes get max)"]
    H --> I["shfl_sync broadcast → canonical chosen_index"]
    I --> J["lane==0: write topk_indices/scores shmem\nAll lanes: selected[k]=chosen_index\n__syncwarp()"]
    J --> F
    D --> K["is_masked lambda via topk_indices shmem"]
    K --> L["Loop k=0..topk-1, XOR reduce, lane 0 writes, __syncwarp()"]
    subgraph "Forward dispatch"
    M["expert_bias.has_data()?"] -- yes --> N["BiasType from bias.dtype"]
    M -- no --> O["BiasType=DataType, expert_bias=nullptr"]
    end
    subgraph "warp_reduce_sum_float"
    P["shfl_down_sync → accumulate to lane 0"]
    P --> Q["shfl_sync broadcast → all lanes"]
    end
Loading

Reviews (9): Last reviewed commit: "fused_router: address review feedback" | Re-trigger Greptile

Comment thread tests/pytorch/test_fused_router_perf.py
Comment on lines +41 to +46
__device__ inline float warp_reduce_sum_float(float val) {
for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return __shfl_sync(0xffffffff, val, 0);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Broadcast from lane 0 should be documented

__shfl_down_sync accumulates the sum only in lane 0 after all steps — unlike the __shfl_xor_sync butterfly approach (used elsewhere in this file) which gives the correct sum to every lane simultaneously. The subsequent __shfl_sync(…, 0) is therefore load-bearing for correctness, not just an optimisation.

Adding a short comment here prevents a future reader from accidentally removing it thinking it's redundant:

Suggested change
__device__ inline float warp_reduce_sum_float(float val) {
for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return __shfl_sync(0xffffffff, val, 0);
}
__device__ inline float warp_reduce_sum_float(float val) {
// __shfl_down_sync accumulates the total only in lane 0;
// the broadcast below is required for all lanes to see the result.
for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return __shfl_sync(0xffffffff, val, 0);
}

Comment on lines +224 to +248
#pragma unroll
for (int k = 0; k < K; ++k) {
CompType val = -std::numeric_limits<CompType>::infinity();
int index = (lane_id < data_size) ? lane_id : 0;
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
bool masked = false;
#pragma unroll
for (int j = 0; j < k; ++j) {
masked |= (selected[j] == i);
}
if (masked) continue;
CompType cur_val = scores[i];
if (cur_val > val) {
val = cur_val;
index = i;
}
}
for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
if (shuffled_val > val) {
val = shuffled_val;
index = shuffled_index;
}
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 OOB thread index initialised to 0 could shadow earlier selections

For threads where lane_id >= data_size, the inner loop body never executes, so val stays at -inf and index is set to 0:

int index = (lane_id < data_size) ? lane_id : 0;

0 is a valid data index that may already have been placed in selected by a previous k iteration. During the XOR-reduction phase, these OOB threads participate (they shuffle -inf values, which can never win the shuffled_val > val comparison), so the final chosen_index remains correct. However, after the broadcast:

selected[k] = chosen_index;

every thread — including OOB ones — writes chosen_index to their register copy of selected, keeping all threads in sync. The algorithm is therefore correct, but initialising the fallback index to a sentinel value (e.g., -1 or data_size - 1) would make the intent clearer and avoid confusion with a real element:

Suggested change
#pragma unroll
for (int k = 0; k < K; ++k) {
CompType val = -std::numeric_limits<CompType>::infinity();
int index = (lane_id < data_size) ? lane_id : 0;
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
bool masked = false;
#pragma unroll
for (int j = 0; j < k; ++j) {
masked |= (selected[j] == i);
}
if (masked) continue;
CompType cur_val = scores[i];
if (cur_val > val) {
val = cur_val;
index = i;
}
}
for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
if (shuffled_val > val) {
val = shuffled_val;
index = shuffled_index;
}
}
int index = (lane_id < data_size) ? lane_id : -1; // -1: sentinel for out-of-range lane

This is purely a readability / defensive-programming concern given that invalid index values can never propagate to the output.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment thread tests/pytorch/test_fused_router_perf.py Outdated
Comment on lines +17 to +21

seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Module-level seed is fragile for tokens_per_expert randomness

The random seed is set once at import time. If pytest collects or runs other tests before this module's tests execute, the global random state will have advanced and test_fused_moe_aux_loss_perf_against_torch (which calls torch.randint for tokens_per_expert) will use an unknown seed. While the numerical correctness check in that test (torch.testing.assert_close(torch_loss, fused_loss)) passes regardless of the specific random values, reproducible benchmarks are easier to debug.

Consider moving the seed setup into each individual test function, or using a pytest fixture to ensure a consistent state per test.

@denera denera self-requested a review April 14, 2026 19:41

@denera denera left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@XiaomingFun233 I went through the Greptile comments and I agree with the suggested changes except for the speedup assertion — we should keep the benchmark test in the suite for manual verification but our CI pipeline is for functional testing, not benchmarking, so the assertion does not make sense. Could you address the remaining issues and rebase the branch on latest TE/main? We can launch the CI on our end for testing.

Please also check out the contributing guidelines, particularly regarding the sign-off for your commits and the license information that needs to be added to the source files.

Thanks!

@XiaomingFun233

Copy link
Copy Markdown
Contributor Author

@XiaomingFun233 I went through the Greptile comments and I agree with the suggested changes except for the speedup assertion — we should keep the benchmark test in the suite for manual verification but our CI pipeline is for functional testing, not benchmarking, so the assertion does not make sense. Could you address the remaining issues and rebase the branch on latest TE/main? We can launch the CI on our end for testing.

Please also check out the contributing guidelines, particularly regarding the sign-off for your commits and the license information that needs to be added to the source files.

Thanks!

ok I will complete this work

- restore forward hot paths to baseline behavior for topk/scores kernels\n- keep warp-level reduction helper for backward normalization\n- handle empty expert_bias safely in fused topk forward

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
Add a lightweight register-based small-k path and keep the generic fallback for compatibility.

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels.

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
@XiaomingFun233 XiaomingFun233 force-pushed the pr/fused-router-topk-opt branch from 066b9ea to 3ad3ad2 Compare May 15, 2026 06:17
@XiaomingFun233

Copy link
Copy Markdown
Contributor Author

Addressed the remaining issues from the review.

  • Kept the benchmark test for manual verification and removed perf-style assertions.
  • Rebased the branch onto the latest TE/main.
  • Added sign-off to the commits.

Please let me know if there is anything else you would like me to adjust

@XiaomingFun233

Copy link
Copy Markdown
Contributor Author

please checkout this new change @denera

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 22, 2026
@XiaomingFun233

XiaomingFun233 commented May 22, 2026

Copy link
Copy Markdown
Contributor Author

@hartsock @tabo please review all these new changes

@XiaomingFun233

Copy link
Copy Markdown
Contributor Author

@ptrendx please check out

XiaomingFun233 and others added 2 commits May 28, 2026 15:15
@XiaomingFun233

Copy link
Copy Markdown
Contributor Author

@denera Addressed the remaining review feedback.

  • Added a comment to warp_reduce_sum_float clarifying that the lane-0 broadcast is required for correctness.
  • Updated naive_topk_and_mask_smallk to use -1 as the out-of-range sentinel index for clarity.
  • Made the perf tests reproducible by resetting the random seed in each test.

The benchmark test remains opt-in for manual verification, perf-style assertions are removed, the branch has been rebased onto the latest TE/main,
and the commits are signed off.

@denera

denera commented May 28, 2026

Copy link
Copy Markdown
Collaborator

/te-ci pytorch

@denera denera left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, merging. Thanks!

@denera denera merged commit f3c2e74 into NVIDIA:main May 28, 2026
23 of 25 checks passed
Baibaifan pushed a commit to Baibaifan/TransformerEngine that referenced this pull request Jun 1, 2026
…rk (NVIDIA#2776)

* fused_router: keep low-risk CUDA optimizations

- restore forward hot paths to baseline behavior for topk/scores kernels\n- keep warp-level reduction helper for backward normalization\n- handle empty expert_bias safely in fused topk forward

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>

* fused_router: specialize naive_topk_and_mask for topk<=8

Add a lightweight register-based small-k path and keep the generic fallback for compatibility.

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>

* tests: add fused router performance benchmark

Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels.

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fused_router: address review feedback

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>

---------

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Baibaifan pushed a commit to Baibaifan/TransformerEngine that referenced this pull request Jun 1, 2026
…rk (NVIDIA#2776)

* fused_router: keep low-risk CUDA optimizations

- restore forward hot paths to baseline behavior for topk/scores kernels\n- keep warp-level reduction helper for backward normalization\n- handle empty expert_bias safely in fused topk forward

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>

* fused_router: specialize naive_topk_and_mask for topk<=8

Add a lightweight register-based small-k path and keep the generic fallback for compatibility.

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>

* tests: add fused router performance benchmark

Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels.

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fused_router: address review feedback

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>

---------

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants