From 33c68179526e5acd129fa10de46ae7a8201f4be5 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 2 Apr 2026 06:39:12 +0000 Subject: [PATCH 01/15] optimizing grouped_topk_multi_group SYCL kernel for MoE routing Signed-off-by: xiaolong Signed-off-by: root --- csrc/moe/grouped_topk_kernels.cpp | 652 ++++++++++++++++++++++++++++++ csrc/moe/moe_ops.h | 13 + csrc/moe/torch_bindings.cpp | 10 + tests/ops/grouped_topk_op.py | 66 ++- tests/register_ops.py | 18 + tests/test_grouped_topk.py | 56 ++- 6 files changed, 777 insertions(+), 38 deletions(-) create mode 100644 csrc/moe/grouped_topk_kernels.cpp diff --git a/csrc/moe/grouped_topk_kernels.cpp b/csrc/moe/grouped_topk_kernels.cpp new file mode 100644 index 000000000..fef60a4af --- /dev/null +++ b/csrc/moe/grouped_topk_kernels.cpp @@ -0,0 +1,652 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include +#include +#include "../dispatch_utils.h" + +namespace vllm { +namespace moe { + +// Type trait: bfloat16 -> float for computation, everything else stays as-is +template +struct compute_type { using type = T; }; + +template <> +struct compute_type { using type = float; }; + +template +using compute_type_t = typename compute_type::type; + +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +static constexpr int WARP_SIZE = 32; +static constexpr int NumNemotronExperts = 512; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = + std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); +static constexpr int MaxNumExpertsUnit = 128; +static constexpr int NumTopGroupScores = 2; +static constexpr int DefaultMaxNumTopExperts = 8; +static constexpr int MaxSupportedTopExperts = 22; +static constexpr int MaxNumTopGroups = 4; + +enum ScoringFunc : int { SCORING_NONE = 0, SCORING_SIGMOID = 1 }; + +template +class VllmGroupedTopKFusedKernel; + +template +class VllmGroupedTopKFusedSmallExpertCountKernel; + +template +inline T_OUT sycl_cast(T_IN val) { + return static_cast(val); +} + +template <> +inline float sycl_cast(sycl::half val) { + return static_cast(val); +} + +template <> +inline float sycl_cast(sycl::ext::oneapi::bfloat16 val) { + return static_cast(val); +} + +template +inline T neg_inf() { + return sycl_cast(-std::numeric_limits::infinity()); +} + +template +inline bool is_finite(const T val) { + return std::isfinite(sycl_cast(val)); +} + +inline float sigmoid_accurate(float x) { + return 1.f / (1.f + sycl::native::exp(-x)); // More efficient approximation Optimized point 1 +} + +template +inline T apply_sigmoid(T val) { + float f = sycl_cast(val); + return sycl_cast(sigmoid_accurate(f)); +} + +template +inline T apply_scoring(T val) { + if constexpr (SF == SCORING_NONE) { + return val; + } else if constexpr (SF == 
SCORING_SIGMOID) { + return apply_sigmoid(val); + } else { + static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID, + "Unsupported ScoringFunc in apply_scoring"); + return val; + } +} + +namespace reduce_topk { + +template +inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, + const T* in_vals, const IdxT* in_idxs, T min_val, + int topk) { + constexpr IdxT invalid_idx = std::numeric_limits::max(); + bool selected[N_IN] = {false}; + + for (int k = 0; k < topk; ++k) { + using CT = compute_type_t; + CT local_best_val = static_cast(min_val); + IdxT local_best_idx = invalid_idx; + int local_best_pos = -1; + + #pragma unroll + for (int i = 0; i < N_IN; ++i) { + if (selected[i]) { + continue; + } + T cand_val = in_vals[i]; + IdxT cand_idx = in_idxs[i]; + if ((cand_val > local_best_val) || + ((cand_val == local_best_val) && (cand_idx < local_best_idx))) { + local_best_val = cand_val; + local_best_idx = cand_idx; + local_best_pos = i; + } + } + + T warp_best_val = sycl::reduce_over_group( + subgroup, local_best_val, sycl::maximum()); + + IdxT warp_best_idx = invalid_idx; + if (local_best_pos != -1 && local_best_val == warp_best_val) { + warp_best_idx = local_best_idx; + } + warp_best_idx = sycl::reduce_over_group( + subgroup, warp_best_idx, sycl::minimum()); + + bool found = (warp_best_idx != invalid_idx); + if (found) { + int insert_pos = k; + while (insert_pos > 0 && out_val[insert_pos - 1] == warp_best_val && + out_idx[insert_pos - 1] > warp_best_idx) { + out_val[insert_pos] = out_val[insert_pos - 1]; + out_idx[insert_pos] = out_idx[insert_pos - 1]; + --insert_pos; + } + out_val[insert_pos] = warp_best_val; + out_idx[insert_pos] = warp_best_idx; + } else { + out_val[k] = min_val; + out_idx[k] = 0; + } + + if (found && local_best_pos != -1 && local_best_val == warp_best_val && + local_best_idx == warp_best_idx) { + selected[local_best_pos] = true; + } + } +} + +template +inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* 
out_idx, + T val, IdxT idx, T min_val, int topk) { + T in_vals[1] = {val}; + IdxT in_idxs[1] = {idx}; + reduceTopK<1>(subgroup, out_val, out_idx, in_vals, in_idxs, min_val, + topk); +} + +} // namespace reduce_topk + +template +SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( + T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias, + int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, + int64_t const topk, int64_t const numExperts, + int64_t const numExpertsPerGroup, bool const renormalize, + double const routedScalingFactor, sycl::nd_item<1> item) { + + constexpr int NumWarps = MaxNumExperts / WARP_SIZE; + constexpr float invalidScoreFloat = -std::numeric_limits::infinity(); + + int threadIdx = item.get_local_id(0); + int blockIdx = item.get_group(0); + if constexpr (UseGroups){ + if (blockIdx >= numTokens) return; + } + int localSize = item.get_local_range(0); + bool has_bias = (routingBias != nullptr); + + int laneIdx = threadIdx % WARP_SIZE; + int warpIdx = threadIdx / WARP_SIZE; + + + topkValues += blockIdx * topk; + topkIndices += blockIdx * topk; + + if constexpr (UseGroups) { + auto subgroup = item.get_sub_group(); + T* scoresToken = scores + static_cast(blockIdx) * numExperts; + T selectedGroupScores[WARP_SIZE]; + int32_t selectedGroupIdx[WARP_SIZE]; + + T groupScore = neg_inf(); + if (laneIdx < numGroup) { + int32_t groupOffset = laneIdx * numExpertsPerGroup; + T largest = neg_inf(); + T secondLargest = neg_inf(); + + for (int32_t i = 0; i < numExpertsPerGroup; ++i) { + T value = apply_scoring(scoresToken[groupOffset + i]); + if (has_bias) { + value = value + sycl_cast(routingBias[groupOffset + i]); + } + if (value > largest) { + secondLargest = largest; + largest = value; + } else if (value > secondLargest) { + secondLargest = value; + } + } + groupScore = has_bias ? 
largest + secondLargest : largest; + } + + reduce_topk::reduceTopK( + subgroup, selectedGroupScores, selectedGroupIdx, + groupScore, laneIdx, neg_inf(), static_cast(topkGroup)); + + bool proceed = false; + if (topkGroup > 0) { + proceed = (selectedGroupScores[topkGroup - 1] != neg_inf()); + } + + if (!proceed) { + for (int i = laneIdx; i < topk; i += WARP_SIZE) { + topkIndices[i] = static_cast(i); + topkValues[i] = 1.0f / static_cast(topk); + } + return; + } + + constexpr int MaxExpertCandidatesPerLane = NumDeepseekExperts / WARP_SIZE; + T localCandidateScores[MaxExpertCandidatesPerLane]; + IdxT localCandidateIdx[MaxExpertCandidatesPerLane]; + T selectedExpertScores[DefaultMaxNumTopExperts]; + IdxT selectedExpertIdx[DefaultMaxNumTopExperts]; + + for (int i = 0; i < MaxExpertCandidatesPerLane; ++i) { + localCandidateScores[i] = neg_inf(); + localCandidateIdx[i] = 0; + } + + int32_t totalCandidates = topkGroup * numExpertsPerGroup; + for (int32_t candidate = laneIdx; candidate < totalCandidates; + candidate += WARP_SIZE) { + int32_t localSlot = candidate / WARP_SIZE; + int32_t selectedGroup = candidate / numExpertsPerGroup; + int32_t expertInGroup = candidate % numExpertsPerGroup; + int32_t gid = selectedGroupIdx[selectedGroup]; + int32_t idx = gid * numExpertsPerGroup + expertInGroup; + T candidateScore = neg_inf(); + + T input = scoresToken[idx]; + if (is_finite(input)) { + T score = apply_scoring(input); + candidateScore = score; + if (has_bias) { + candidateScore = candidateScore + sycl_cast(routingBias[idx]); + } + } + + localCandidateScores[localSlot] = candidateScore; + localCandidateIdx[localSlot] = static_cast(idx); + } + + reduce_topk::reduceTopK( + subgroup, selectedExpertScores, selectedExpertIdx, + localCandidateScores, localCandidateIdx, neg_inf(), static_cast(topk)); + + for (int i = 1; i < topk; ++i) { + T score = selectedExpertScores[i]; + IdxT idx = selectedExpertIdx[i]; + int j = i; + while (j > 0 && + ((selectedExpertScores[j - 1] < score) || + 
((selectedExpertScores[j - 1] == score) && + (selectedExpertIdx[j - 1] > idx)))) { + selectedExpertScores[j] = selectedExpertScores[j - 1]; + selectedExpertIdx[j] = selectedExpertIdx[j - 1]; + --j; + } + selectedExpertScores[j] = score; + selectedExpertIdx[j] = idx; + } + + float laneUnbiased = 0.0f; + IdxT laneIdxOut = 0; + if (laneIdx < topk) { + laneIdxOut = selectedExpertIdx[laneIdx]; + T in = scoresToken[static_cast(laneIdxOut)]; + laneUnbiased = sycl_cast(apply_scoring(in)); + } + + float scale = static_cast(routedScalingFactor); + if (renormalize) { + float topkSum = 1e-20f; + topkSum += sycl::reduce_over_group( + subgroup, laneUnbiased,sycl::plus()); + scale /= topkSum; + } + + if (laneIdx < topk) { + topkIndices[laneIdx] = laneIdxOut; + topkValues[laneIdx] = laneUnbiased * scale; + } + return; + } else { + + float* smemScoreSigmoid = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + float* smemScoreBias = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + float topScores[MaxNumTopExperts] = {invalidScoreFloat}; + int32_t topExperts[MaxNumTopExperts] = {0}; + float expertScoreGroup[MaxNumTopGroups] = {invalidScoreFloat}; + int32_t expertIdxGroup[MaxNumTopGroups] = {0}; + auto group = item.get_sub_group(); + + for (int expert = threadIdx; expert < numExperts; expert += localSize) { + int64_t scoreIdx = int64_t{blockIdx} * int64_t{numExperts} + expert; + float score = sycl_cast(scores[scoreIdx]); + float scoreSigmoid = apply_scoring(score); + smemScoreSigmoid[expert] = scoreSigmoid; + smemScoreBias[expert] = has_bias + ? 
(scoreSigmoid + sycl_cast(routingBias[expert])) + : scoreSigmoid; + } + + if constexpr (MaxNumExperts > MaxNumExpertsUnit) { + constexpr int NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; + constexpr int NumInterTopK = NumExpertWarps * MaxNumTopExperts; + float* smemInterTopScores = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + int32_t* smemInterTopExperts = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + + if (warpIdx < NumExpertWarps) { + int offset = warpIdx * WARP_SIZE * MaxNumTopGroups; + + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + int expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = (offset + expertIdx < numExperts) + ? smemScoreBias[offset + expertIdx] + : invalidScoreFloat; + } + reduce_topk::reduceTopK( + group, topScores, topExperts, expertScoreGroup, expertIdxGroup, + invalidScoreFloat, static_cast(topk)); + + if (laneIdx < MaxNumTopExperts) { + if (laneIdx < topk) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } else { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreFloat; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1; + } + } + } + item.barrier(sycl::access::fence_space::local_space); + if (warpIdx == 0) { + constexpr int NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; + float intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; + + for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) { + int ii = i / WARP_SIZE; + if (i < NumInterTopK) { + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; + } else { + intermediateScore[ii] = invalidScoreFloat; + intermediateExpert[ii] = MaxNumExperts - 1; + } + } + + 
reduce_topk::reduceTopK( + group, topScores, topExperts, intermediateScore, intermediateExpert, + invalidScoreFloat, static_cast(topk)); + } + } else { + if (warpIdx == 0) { + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + int32_t expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = (expertIdx < numExperts) + ? smemScoreBias[expertIdx] + : invalidScoreFloat; + } + reduce_topk::reduceTopK( + group, topScores, topExperts, expertScoreGroup, expertIdxGroup, + invalidScoreFloat, static_cast(topk)); + } + } + + if (warpIdx == 0) { + int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F; + float finalScore = static_cast(scoreNorm * routedScalingFactor); + float topk_sum = 1e-20f; + if (renormalize) { + topk_sum += sycl::reduce_over_group(group, scoreNorm,sycl::plus()); + finalScore /= topk_sum; + } + if (laneIdx < topk) { + topkValues[laneIdx] = finalScore; + topkIndices[laneIdx] = expertIdx; + } + } + } // end if constexpr (!UseGroups) +} + +template +void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices, + BiasT const* bias, int64_t const num_tokens, + int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, + bool const renormalize, double const routed_scaling_factor, + bool enable_pdl = false, sycl::queue queue = sycl::queue()) { + int64_t experts_per_group = num_experts / n_group; + bool is_single_group = + (n_group == 1) && (topk_group == 1) && + (num_experts <= MaxSupportedExpertCount) && + (topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts); + + #define LAUNCH_SMALL_KERNEL(MAX_EXPERTS, USE_GROUPS, MAX_TOP_EXPERTS, NUM_THREADS) \ + do { \ + size_t local_size = static_cast(NUM_THREADS); \ + size_t global_size = static_cast(num_tokens) * local_size; \ + queue.submit([&](sycl::handler& cgh) { \ + cgh.parallel_for>( \ + 
sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), \ + [=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { \ + grouped_topk_fused_small_expert_count_kernel( \ + scores, topk_values, topk_indices, bias, \ + num_tokens, n_group, topk_group, topk, num_experts, \ + experts_per_group, renormalize, routed_scaling_factor, item); \ + }); \ + }); \ + } while (0) + + if (is_single_group) { + if (num_experts == NumNemotronExperts && n_group == 1 && + topk == MaxSupportedTopExperts) { + LAUNCH_SMALL_KERNEL(NumNemotronExperts, false, + MaxSupportedTopExperts, + ((NumNemotronExperts + MaxNumExpertsUnit - 1) / + MaxNumExpertsUnit) * WARP_SIZE); + } else if (num_experts > NumKimiK2Experts && + num_experts <= MaxSupportedExpertCount) { + LAUNCH_SMALL_KERNEL(MaxSupportedExpertCount, false, + DefaultMaxNumTopExperts, + ((MaxSupportedExpertCount + MaxNumExpertsUnit - 1) / + MaxNumExpertsUnit) * WARP_SIZE); + } else if (num_experts > MaxNumExpertsUnit && + num_experts <= NumKimiK2Experts) { + LAUNCH_SMALL_KERNEL(NumKimiK2Experts, false, + DefaultMaxNumTopExperts, + ((NumKimiK2Experts + MaxNumExpertsUnit - 1) / + MaxNumExpertsUnit) * WARP_SIZE); + } else { + LAUNCH_SMALL_KERNEL(MaxNumExpertsUnit, false, + DefaultMaxNumTopExperts, + WARP_SIZE); + } + } else { + LAUNCH_SMALL_KERNEL(NumDeepseekExperts, true, + DefaultMaxNumTopExperts, + WARP_SIZE); + } + + #undef LAUNCH_SMALL_KERNEL + +} + +#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \ + template void invokeNoAuxTc( \ + T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \ + int64_t const num_tokens, int64_t const num_experts, \ + int64_t const n_group, int64_t const topk_group, int64_t const topk, \ + bool const renormalize, double const routed_scaling_factor, \ + bool enable_pdl, sycl::queue queue); + +INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, 
sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +} // end namespace moe +} // namespace vllm + +std::tuple grouped_topk_multi_group( + torch::Tensor const& hidden_states, + torch::Tensor const& gating_output, + int64_t const n_topk, + bool const renormalize, + int64_t const n_expert_group, + int64_t const n_topk_group, + c10::string_view const scoring_func, + double const routed_scaling_factor, + c10::optional const& bias) { + auto data_type = gating_output.scalar_type(); + bool has_bias = bias.has_value() && bias->defined(); + auto bias_type = has_bias ? 
bias->scalar_type() : torch::kFloat32; + auto input_size = gating_output.sizes(); + TORCH_CHECK(input_size.size() == 2, "gating_output must be a 2D Tensor"); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + int64_t n_group = n_expert_group; + int64_t topk_group = n_topk_group; + int64_t topk = n_topk; + + TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0], + "Number of tokens mismatch"); + TORCH_CHECK(n_group > 0, "n_group must be positive"); + TORCH_CHECK(topk > 0, "topk must be positive"); + TORCH_CHECK(topk_group > 0, "topk_group must be positive"); + TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group"); + TORCH_CHECK(num_experts % n_group == 0, + "num_experts should be divisible by n_group"); + TORCH_CHECK(n_group <= 32, + "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= topk_group * (num_experts / n_group), + "topk must be <= topk_group * (num_experts / n_group)"); + TORCH_CHECK(scoring_func == "sigmoid" || scoring_func == "softmax", + "Unsupported scoring_func: ", scoring_func); + // Pre-apply softmax on host side then use SCORING_NONE in kernel, + // because softmax requires a full-row reduction that the kernel doesn't support. 
+ torch::Tensor scores; + vllm::moe::ScoringFunc sf; + if (scoring_func == "sigmoid") { + sf = vllm::moe::SCORING_SIGMOID; + scores = gating_output; + } else if (scoring_func == "softmax") { + sf = vllm::moe::SCORING_NONE; + scores = torch::softmax(gating_output, /*dim=*/-1); + } + + // Always output float32 for topk_values (eliminates Python-side conversion) + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kFloat32).device(gating_output.device())); + torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kInt32).device(gating_output.device())); + + auto device_idx = gating_output.device().index(); + auto stream = c10::xpu::getCurrentXPUStream(device_idx).queue(); + +#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \ + do { \ + switch (sf) { \ + case vllm::moe::SCORING_NONE: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(scores.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + (has_bias ? reinterpret_cast(bias->data_ptr()) : nullptr), num_tokens, \ + num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, false, stream); \ + break; \ + case vllm::moe::SCORING_SIGMOID: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(scores.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + (has_bias ? 
reinterpret_cast(bias->data_ptr()) : nullptr), num_tokens, \ + num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, false, stream); \ + break; \ + default: \ + throw std::invalid_argument("Unsupported scoring_func"); \ + break; \ + } \ + } while (0) + +#define LAUNCH_KERNEL(T, IdxT) \ + do{ \ + switch (bias_type) { \ + case torch::kFloat16: \ + LAUNCH_KERNEL_SF(T, sycl::half, IdxT); \ + break; \ + case torch::kFloat32: \ + LAUNCH_KERNEL_SF(T, float, IdxT); \ + break; \ + case torch::kBFloat16: \ + LAUNCH_KERNEL_SF(T, sycl::ext::oneapi::bfloat16, IdxT); \ + break; \ + default: \ + throw std::invalid_argument( \ + "Invalid bias dtype, only supports float16, float32, and " \ + "bfloat16"); \ + break; \ + } \ + } \ + while (0) + + + switch (data_type) { + case torch::kFloat16: + LAUNCH_KERNEL(sycl::half, int32_t); + break; + case torch::kFloat32: + LAUNCH_KERNEL(float, int32_t); + break; + case torch::kBFloat16: + LAUNCH_KERNEL(sycl::ext::oneapi::bfloat16, int32_t); + break; + default: + throw std::invalid_argument( + "Invalid dtype, only supports float16, float32, and bfloat16"); + break; + } +#undef LAUNCH_KERNEL +#undef LAUNCH_KERNEL_SF + return {topk_values, topk_indices}; +} \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 9061839a2..88f43762e 100755 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -56,6 +56,19 @@ std::tuple fused_grouped_topk( const double routed_scaling_factor, const c10::optional& bias); +std::tuple grouped_topk_multi_group( + const torch::Tensor& hidden_states, + const torch::Tensor& gating_output, + const int64_t n_topk, + const bool renormalize, + const int64_t n_expert_group, + const int64_t n_topk_group, + const c10::string_view scoring_func, + const double routed_scaling_factor, + const c10::optional& bias); + + + void topk_softmax( torch::Tensor& topk_weights, torch::Tensor& topk_indices, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 
47eb1d61d..0f44d34ed 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -47,6 +47,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor? maybe_expert_map) -> () "); m.impl("moe_lora_align_block_size", torch::kXPU, &moe_lora_align_block_size); + // Apply grouped topk routing to select experts. m.def( "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " @@ -62,6 +63,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "scoring_func, float routed_scaling_factor, Tensor? bias=None) -> " "(Tensor, Tensor)"); m.impl("fused_grouped_topk", torch::kXPU, &fused_grouped_topk); + + // Fused Grouped TopK (multi-group optimized path) + m.def( + "grouped_topk_multi_group(Tensor hidden_states, Tensor gating_output, int " + "n_topk, " + "bool renormalize, int n_expert_group, int n_topk_group, str " + "scoring_func, float routed_scaling_factor, Tensor? bias=None) -> " + "(Tensor, Tensor)"); + m.impl("grouped_topk_multi_group", torch::kXPU, &grouped_topk_multi_group); // Apply topk softmax to the gating outputs. m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! 
" diff --git a/tests/ops/grouped_topk_op.py b/tests/ops/grouped_topk_op.py index 40fca3a98..6c31e1a4f 100644 --- a/tests/ops/grouped_topk_op.py +++ b/tests/ops/grouped_topk_op.py @@ -20,50 +20,60 @@ def grouped_topk( e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" + # Move to CPU to avoid XPU OOM on intermediate tensors if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) elif scoring_func == "sigmoid": scores = gating_output.sigmoid() else: raise ValueError(f"Unsupported scoring function: {scoring_func}") + num_token = scores.size(0) if e_score_correction_bias is not None: # Store original scores before applying correction bias. We use biased # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, - sorted=False)[1] # [n, top_k_group] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = True + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ + 1 + ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.size(-1) // 
num_expert_group).reshape(num_token, -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=use_sorted + ) + if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_weights = topk_weights * routed_scaling_factor + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + def fused_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -107,3 +117,21 @@ def fused_grouped_topk_sycl( renormalize, num_expert_group, topk_group, scoring_func, routed_scaling_factor, e_score_correction_bias) + + +def grouped_topk_multi_group( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + return ops.grouped_topk_multi_group(hidden_states, gating_output, topk, + renormalize, num_expert_group, + topk_group, scoring_func, + routed_scaling_factor, + e_score_correction_bias) diff --git a/tests/register_ops.py b/tests/register_ops.py index 0508ed06d..b6765efd1 100644 
--- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -473,6 +473,24 @@ def fused_grouped_topk( routed_scaling_factor, e_score_correction_bias) +def grouped_topk_multi_group( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, +): + return torch.ops._moe_C.grouped_topk_multi_group(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group, + scoring_func, + routed_scaling_factor, + e_score_correction_bias) + def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indices: torch.Tensor, diff --git a/tests/test_grouped_topk.py b/tests/test_grouped_topk.py index ec383a8a0..7eb904ff4 100644 --- a/tests/test_grouped_topk.py +++ b/tests/test_grouped_topk.py @@ -3,7 +3,7 @@ import torch from tests.ops.grouped_topk_op import (fused_grouped_topk, - fused_grouped_topk_sycl, grouped_topk) + fused_grouped_topk_sycl, grouped_topk,grouped_topk_multi_group) from tests.utils import seed_everything #override pytest parameters when enable mini pytest @@ -18,14 +18,14 @@ } -@pytest.mark.parametrize("n_token", [1, 33, 64]) +@pytest.mark.parametrize("n_token", [1, 33, 64, 50000,100000]) @pytest.mark.parametrize("n_hidden", [1024, 2048]) -@pytest.mark.parametrize("n_expert", [16]) -@pytest.mark.parametrize("topk", [2]) -@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("n_expert", [256]) +@pytest.mark.parametrize("topk", [8]) +@pytest.mark.parametrize("renormalize", [False,True]) @pytest.mark.parametrize("num_expert_group", [8]) -@pytest.mark.parametrize("topk_group", [2]) -@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("topk_group", [4]) +@pytest.mark.parametrize("scoring_func", ["sigmoid",'softmax']) 
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) @@ -72,22 +72,40 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias) + test_topk_weights_multi_group, test_topk_ids_multi_group = grouped_topk_multi_group( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) if renormalize: + # torch.testing.assert_close(baseline_topk_weights, + # test_topk_weights, + # atol=2e-2, + # rtol=0) + # torch.testing.assert_close(baseline_topk_weights, + # test_topk_weights_sycl, + # atol=2e-2, + # rtol=0) torch.testing.assert_close(baseline_topk_weights, - test_topk_weights, - atol=2e-2, - rtol=0) - torch.testing.assert_close(baseline_topk_weights, - test_topk_weights_sycl, + test_topk_weights_multi_group, atol=2e-2, rtol=0) + # torch.testing.assert_close(baseline_topk_ids, + # test_topk_ids, + # atol=0, + # rtol=0) + # torch.testing.assert_close(baseline_topk_ids, + # test_topk_ids_sycl, + # atol=0, + # rtol=0) torch.testing.assert_close(baseline_topk_ids, - test_topk_ids, - atol=0, - rtol=0) - torch.testing.assert_close(baseline_topk_ids, - test_topk_ids_sycl, - atol=0, - rtol=0) + test_topk_ids_multi_group, + atol=0, + rtol=0) \ No newline at end of file From cd261ddedc99863142d639e9c60f240678ed7a08 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 3 Apr 2026 08:41:47 +0000 Subject: [PATCH 02/15] optimizing `fused_grouped_topk` SYCL kernel for MoE expert routing Signed-off-by: root Signed-off-by: xiaolong Signed-off-by: root --- csrc/moe/fused_grouped_topk.cpp | 1069 ++++++++++++++++------------- 
csrc/moe/grouped_topk_kernels.cpp | 652 ------------------ csrc/moe/moe_ops.h | 12 - csrc/moe/torch_bindings.cpp | 26 +- tests/ops/grouped_topk_op.py | 26 +- tests/register_ops.py | 18 - tests/test_grouped_topk.py | 55 +- 7 files changed, 643 insertions(+), 1215 deletions(-) delete mode 100644 csrc/moe/grouped_topk_kernels.cpp diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index 49471084e..54705c41d 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -1,511 +1,648 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include #include - -#include "../utils.h" +#include +#include +#include +#include #include "../dispatch_utils.h" namespace vllm { -namespace GroupedTopKImpl { - -enum class ScoringFunc { - DEFAULT = 0, - SOFTMAX = 1, - SIGMOID = 2, -}; - -template -struct Fused_Grouped_Topk { - static constexpr int sub_group_size = 32; - static constexpr int max_group_size = 1024; - static constexpr int malloc_per_item = MAX_EXPERT_GROUPS; - static constexpr float kNegInfinity = INFINITY * -1; - - Fused_Grouped_Topk( - float* topk_weights, - int* topk_ids, - const T* gating_output, - const T* e_score_correction_bias, - const double routed_scaling_factor, - const ScoringFunc scoring_mode, - const bool renormalize, - const int tokens, - const int experts, - const int top_k, - const int num_expert_group, - const int topk_group) - : topk_weights(topk_weights), - topk_ids(topk_ids), - gating_output(gating_output), - e_score_correction_bias(e_score_correction_bias), - routed_scaling_factor(routed_scaling_factor), - scoring_mode(scoring_mode), - renormalize(renormalize), - tokens(tokens), - experts(experts), - top_k(top_k), - num_expert_group(num_expert_group), - topk_group(topk_group) {} - - static inline sycl::nd_range<3> - get_nd_range(const int tokens, const int experts) { - int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; - int group_size = (experts + calc_per_item - 1) / calc_per_item; - group_size = group_size < sub_group_size ? sub_group_size : group_size; - group_size = group_size < max_group_size ? 
group_size : max_group_size; - int sub_groups_per_group = - (group_size + sub_group_size - 1) / sub_group_size; - group_size = sub_groups_per_group * sub_group_size; - int global_size = - (tokens + sub_groups_per_group - 1) / sub_groups_per_group; - - sycl::range<3> local(1, 1, group_size); - sycl::range<3> global(1, 1, global_size); - return sycl::nd_range<3>(global * local, local); - } +namespace moe { - static inline float Sigmoid(float x) { - return 1.0f / (1.0f + sycl::native::exp(-x)); - } +// Type trait: bfloat16 -> float for computation, everything else stays as-is +template +struct compute_type { using type = T; }; - [[sycl::reqd_sub_group_size(sub_group_size)]] void - operator()(sycl::nd_item<3> item) const { - int group_id = item.get_group_linear_id(); - int local_range = item.get_local_range(2); - int sub_groups_per_group = local_range / sub_group_size; - int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; +template <> +struct compute_type { using type = float; }; - int experts_per_group = experts / num_expert_group; +template +using compute_type_t = typename compute_type::type; + +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +static constexpr int WARP_SIZE = 32; +static constexpr int NumNemotronExperts = 512; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = + std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); +static constexpr int MaxNumExpertsUnit = 128; +static constexpr int NumTopGroupScores = 2; +static constexpr int DefaultMaxNumTopExperts = 8; +static constexpr int MaxSupportedTopExperts = 22; +static constexpr int MaxNumTopGroups = 4; + +enum ScoringFunc : int { SCORING_NONE = 0, SCORING_SIGMOID = 1 }; + +template +class VllmGroupedTopKFusedKernel; + +template +class VllmGroupedTopKFusedSmallExpertCountKernel; + +template +inline T_OUT sycl_cast(T_IN val) { + return static_cast(val); +} - sycl::sub_group sg = 
item.get_sub_group(); - int sg_id = sg.get_group_id(); - int sg_local_id = sg.get_local_id(); - int tid = group_id * sub_groups_per_group + sg_id; +template <> +inline float sycl_cast(sycl::half val) { + return static_cast(val); +} - if (tid >= tokens) { - return; // Out of bounds - } +template <> +inline float sycl_cast(sycl::ext::oneapi::bfloat16 val) { + return static_cast(val); +} - T load_elems[malloc_per_item]; - int local_idx[malloc_per_item]; - T bias[malloc_per_item]; +template +inline T neg_inf() { + return sycl_cast(-std::numeric_limits::infinity()); +} - int start_offset = sg_local_id * calc_per_item; - int local_num = calc_per_item; +template +inline bool is_finite(const T val) { + return std::isfinite(sycl_cast(val)); +} +inline float sigmoid_accurate(float x) { + return 1.f / (1.f + sycl::native::exp(-x)); // More efficient approximation Optimized point 1 +} - if (start_offset + local_num >= experts) { - local_num = experts - start_offset; - if (local_num < 0) { - local_num = 0; // No elements to process - } - } +template +inline T apply_sigmoid(T val) { + float f = sycl_cast(val); + return sycl_cast(sigmoid_accurate(f)); - for (int e = 0; e < calc_per_item; ++e) { - load_elems[e] = kNegInfinity; - local_idx[e] = -1; - bias[e] = 0.0f; // Initialize bias to zero - } +} - for (int e = 0; e < local_num; ++e) { - load_elems[e] = gating_output[tid * experts + start_offset + e]; - } +template +inline T apply_scoring(T val) { + if constexpr (SF == SCORING_NONE) { + return val; + } else if constexpr (SF == SCORING_SIGMOID) { + return apply_sigmoid(val); + } +} - T local_elems[malloc_per_item]; +namespace reduce_topk { + +template +inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, + const T* in_vals, const IdxT* in_idxs, T min_val, + int topk) { + constexpr IdxT invalid_idx = std::numeric_limits::max(); + bool selected[N_IN] = {false}; + + for (int k = 0; k < topk; ++k) { + using CT = compute_type_t; + CT local_best_val = 
static_cast(min_val); + IdxT local_best_idx = invalid_idx; + int local_best_pos = -1; + + #pragma unroll + for (int i = 0; i < N_IN; ++i) { + if (selected[i]) { + continue; + } + T cand_val = in_vals[i]; + IdxT cand_idx = in_idxs[i]; + if ((cand_val > local_best_val) || + ((cand_val == local_best_val) && (cand_idx < local_best_idx))) { + local_best_val = cand_val; + local_best_idx = cand_idx; + local_best_pos = i; + } + } - for (int e = 0; e < local_num; ++e) { - local_elems[e] = load_elems[e]; - local_idx[e] = start_offset + e; - } + T warp_best_val = sycl::reduce_over_group( + subgroup, local_best_val, sycl::maximum()); - if (scoring_mode == ScoringFunc::SOFTMAX) { - float softmax_max = kNegInfinity; - for (int e = 0; e < local_num; ++e) { - float s = load_elems[e]; - softmax_max = (softmax_max > s) ? softmax_max : s; - } - for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { - float other_val = sycl::permute_group_by_xor(sg, softmax_max, offset); - softmax_max = (softmax_max > other_val) ? 
softmax_max : other_val; - } - float softmax_sum = 0.0f; - for (int e = 0; e < local_num; ++e) { - float s = local_elems[e]; - softmax_sum += sycl::native::exp(s - softmax_max); - } - for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { - float other_val = sycl::permute_group_by_xor(sg, softmax_sum, offset); - softmax_sum += other_val; - } - for (int e = 0; e < local_num; ++e) { - float s = local_elems[e]; - local_elems[e] = sycl::native::exp(s - softmax_max) / softmax_sum; - } - } else if (scoring_mode == ScoringFunc::SIGMOID) { - for (int e = 0; e < local_num; ++e) { - float s = load_elems[e]; - load_elems[e] = Sigmoid(s); - } - for (int e = 0; e < local_num; ++e) { - local_elems[e] = load_elems[e]; - } - } + IdxT warp_best_idx = invalid_idx; + if (local_best_pos != -1 && local_best_val == warp_best_val) { + warp_best_idx = local_best_idx; + } + warp_best_idx = sycl::reduce_over_group( + subgroup, warp_best_idx, sycl::minimum()); + + bool found = (warp_best_idx != invalid_idx); + if (found) { + int insert_pos = k; + while (insert_pos > 0 && out_val[insert_pos - 1] == warp_best_val && + out_idx[insert_pos - 1] > warp_best_idx) { + out_val[insert_pos] = out_val[insert_pos - 1]; + out_idx[insert_pos] = out_idx[insert_pos - 1]; + --insert_pos; + } + out_val[insert_pos] = warp_best_val; + out_idx[insert_pos] = warp_best_idx; + } else { + out_val[k] = min_val; + out_idx[k] = 0; + } - bool has_bias = e_score_correction_bias != nullptr; - if (has_bias) { - for (int e = 0; e < local_num; ++e) { - bias[e] = e_score_correction_bias[start_offset + e]; - } + if (found && local_best_pos != -1 && local_best_val == warp_best_val && + local_best_idx == warp_best_idx) { + selected[local_best_pos] = true; + } } +} - // perform topk_group groups - // 1 calculate each group scores - float group_scores[malloc_per_item * 2]; - for (int i = 0; i < num_expert_group * 2; ++i) { - group_scores[i] = kNegInfinity; - } - for (int i = 0; i < local_num; ++i) { - float b = bias[i]; - 
float score = local_elems[i] + b; - int i_group = (calc_per_item * sg_local_id + i) / experts_per_group; - float group_max = group_scores[i_group]; - float group_next_max = group_scores[num_expert_group + i_group]; - if (score > group_max) { - group_next_max = group_max; - group_max = score; - } else if (score > group_next_max) { - group_next_max = score; - } - group_scores[i_group] = group_max; - group_scores[num_expert_group + i_group] = group_next_max; - } - for (int i = 0; i < num_expert_group; ++i) { - float group_max = group_scores[i]; - float group_next_max = group_scores[num_expert_group + i]; - - float max1 = sycl::reduce_over_group( - sg, sycl::max(group_max, group_next_max), sycl::maximum<>()); - float local_second = - (group_max < max1 && group_max > -INFINITY) ? group_max : -INFINITY; - local_second = (group_next_max < max1 && group_next_max > local_second) - ? group_next_max - : local_second; - float max2 = sycl::reduce_over_group(sg, local_second, sycl::maximum<>()); - group_scores[i] = max1 + (has_bias ? 
max2 : 0.0f); +template +inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, + T val, IdxT idx, T min_val, int topk) { + T in_vals[1] = {val}; + IdxT in_idxs[1] = {idx}; + reduceTopK<1>(subgroup, out_val, out_idx, in_vals, in_idxs, min_val, + topk); +} + +} // namespace reduce_topk + +template +SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( + T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias, + int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, + int64_t const topk, int64_t const numExperts, + int64_t const numExpertsPerGroup, bool const renormalize, + double const routedScalingFactor, sycl::nd_item<1> item) { + + constexpr int NumWarps = MaxNumExperts / WARP_SIZE; + constexpr float invalidScoreFloat = -std::numeric_limits::infinity(); + + int threadIdx = item.get_local_id(0); + int blockIdx = item.get_group(0); + if constexpr (UseGroups){ + if (blockIdx >= numTokens) return; } + int localSize = item.get_local_range(0); + bool has_bias = (routingBias != nullptr); + + int laneIdx = threadIdx % WARP_SIZE; + int warpIdx = threadIdx / WARP_SIZE; + + + topkValues += blockIdx * topk; + topkIndices += blockIdx * topk; + + if constexpr (UseGroups) { + auto subgroup = item.get_sub_group(); + T* scoresToken = scores + static_cast(blockIdx) * numExperts; + T selectedGroupScores[WARP_SIZE]; + int32_t selectedGroupIdx[WARP_SIZE]; + + T groupScore = neg_inf(); + if (laneIdx < numGroup) { + int32_t groupOffset = laneIdx * numExpertsPerGroup; + T largest = neg_inf(); + T secondLargest = neg_inf(); + + for (int32_t i = 0; i < numExpertsPerGroup; ++i) { + T value = apply_scoring(scoresToken[groupOffset + i]); + if (has_bias) { + value = value + sycl_cast(routingBias[groupOffset + i]); + } + if (value > largest) { + secondLargest = largest; + largest = value; + } else if (value > secondLargest) { + secondLargest = value; + } + } + groupScore = has_bias ? 
largest + secondLargest : largest; + } - // 2 find topk_group groups as kNegInfinity - int group_topk_idx[malloc_per_item]; - for (int k = 0; k < topk_group; ++k) { - float k_max = group_scores[0]; - int k_max_idx = 0; - for (int e = 1; e < num_expert_group; ++e) { - float score = group_scores[e]; - - if (score > k_max) { - k_max = score; - k_max_idx = e; + reduce_topk::reduceTopK( + subgroup, selectedGroupScores, selectedGroupIdx, + groupScore, laneIdx, neg_inf(), static_cast(topkGroup)); + + bool proceed = false; + if (topkGroup > 0) { + proceed = (selectedGroupScores[topkGroup - 1] != neg_inf()); } - } - group_scores[k_max_idx] = kNegInfinity; - group_topk_idx[k] = k_max_idx; - } - // 3 mask no-topk_group groups - for (int i = 0; i < calc_per_item; ++i) { - bool is_masked = true; - for (int k = 0; k < topk_group; ++k) { - if ((local_idx[i] / experts_per_group) == group_topk_idx[k]) { - is_masked = false; - break; + if (!proceed) { + for (int i = laneIdx; i < topk; i += WARP_SIZE) { + topkIndices[i] = static_cast(i); + topkValues[i] = 1.0f / static_cast(topk); + } + return; + } + + constexpr int MaxExpertCandidatesPerLane = NumDeepseekExperts / WARP_SIZE; + T localCandidateScores[MaxExpertCandidatesPerLane]; + IdxT localCandidateIdx[MaxExpertCandidatesPerLane]; + T selectedExpertScores[DefaultMaxNumTopExperts]; + IdxT selectedExpertIdx[DefaultMaxNumTopExperts]; + + for (int i = 0; i < MaxExpertCandidatesPerLane; ++i) { + localCandidateScores[i] = neg_inf(); + localCandidateIdx[i] = 0; } - } - if (is_masked) { - local_elems[i] = kNegInfinity; - } - } - // Perform top-k selection - T topk_weights_local[malloc_per_item]; - int topk_ids_local[malloc_per_item]; - - for (int k = 0; k < top_k; ++k) { - float k_max = kNegInfinity; - int k_max_idx = -1; - int remove_ix = -1; - for (int e = 0; e < calc_per_item; ++e) { - float le = local_elems[e]; - float b = bias[e]; - float my_val = le + b; - int my_idx = local_idx[e]; - for (int offset = sub_group_size / 2; offset > 0; 
offset /= 2) { - float other_val = sycl::permute_group_by_xor(sg, my_val, offset); - int other_idx = sycl::permute_group_by_xor(sg, my_idx, offset); - if (other_val > my_val || - (other_val == my_val && other_idx < my_idx)) { - my_val = other_val; - my_idx = other_idx; - } + int32_t totalCandidates = topkGroup * numExpertsPerGroup; + for (int32_t candidate = laneIdx; candidate < totalCandidates; + candidate += WARP_SIZE) { + int32_t localSlot = candidate / WARP_SIZE; + int32_t selectedGroup = candidate / numExpertsPerGroup; + int32_t expertInGroup = candidate % numExpertsPerGroup; + int32_t gid = selectedGroupIdx[selectedGroup]; + int32_t idx = gid * numExpertsPerGroup + expertInGroup; + T candidateScore = neg_inf(); + + T input = scoresToken[idx]; + if (is_finite(input)) { + T score = apply_scoring(input); + candidateScore = score; + if (has_bias) { + candidateScore = candidateScore + sycl_cast(routingBias[idx]); + } + } + + localCandidateScores[localSlot] = candidateScore; + localCandidateIdx[localSlot] = static_cast(idx); } - if (my_val > k_max || (my_val == k_max && my_idx < k_max_idx)) { - k_max = my_val; - k_max_idx = my_idx; - - if (k_max_idx == local_idx[e]) { - remove_ix = e; // Mark this index for removal - } else - remove_ix = -1; + + reduce_topk::reduceTopK( + subgroup, selectedExpertScores, selectedExpertIdx, + localCandidateScores, localCandidateIdx, neg_inf(), static_cast(topk)); + + for (int i = 1; i < topk; ++i) { + T score = selectedExpertScores[i]; + IdxT idx = selectedExpertIdx[i]; + int j = i; + while (j > 0 && + ((selectedExpertScores[j - 1] < score) || + ((selectedExpertScores[j - 1] == score) && + (selectedExpertIdx[j - 1] > idx)))) { + selectedExpertScores[j] = selectedExpertScores[j - 1]; + selectedExpertIdx[j] = selectedExpertIdx[j - 1]; + --j; + } + selectedExpertScores[j] = score; + selectedExpertIdx[j] = idx; + } + + float laneUnbiased = 0.0f; + IdxT laneIdxOut = 0; + if (laneIdx < topk) { + laneIdxOut = selectedExpertIdx[laneIdx]; + T 
in = scoresToken[static_cast(laneIdxOut)]; + laneUnbiased = sycl_cast(apply_scoring(in)); + } + + float scale = static_cast(routedScalingFactor); + if (renormalize) { + float topkSum = 1e-20f; + topkSum += sycl::reduce_over_group( + subgroup, laneUnbiased,sycl::plus()); + scale /= topkSum; + } + + if (laneIdx < topk) { + topkIndices[laneIdx] = laneIdxOut; + topkValues[laneIdx] = laneUnbiased * scale; } - } - - int select_item = k_max_idx / calc_per_item; - int select_elem = k_max_idx % calc_per_item; - k_max = local_elems[select_elem]; - k_max = sycl::group_broadcast(sg, k_max, select_item); - if (remove_ix != -1) { - local_elems[remove_ix] = - kNegInfinity; // Reset the score to avoid re-selection - local_idx[remove_ix] = -1; - remove_ix = -1; - } - - topk_weights_local[k] = k_max; - topk_ids_local[k] = k_max_idx < 0 ? k : k_max_idx; + return; + } else { + + T* smemScoreSigmoid = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + T* smemScoreBias = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + T invalidScoreT = neg_inf(); + T topScores[MaxNumTopExperts] = {neg_inf()}; + int32_t topExperts[MaxNumTopExperts] = {0}; + T expertScoreGroup[MaxNumTopGroups] = {neg_inf()}; + int32_t expertIdxGroup[MaxNumTopGroups] = {0}; + auto group = item.get_sub_group(); + + for (int expert = threadIdx; expert < numExperts; expert += localSize) { + int64_t scoreIdx = int64_t{blockIdx} * int64_t{numExperts} + expert; + T score = scores[scoreIdx]; + T scoreSigmoid = apply_scoring(score); + smemScoreSigmoid[expert] = scoreSigmoid; + smemScoreBias[expert] = has_bias + ? 
(scoreSigmoid + sycl_cast(routingBias[expert])) + : scoreSigmoid; } - if (renormalize) { - // Renormalize the top-k weights - float sum = 0; - for (int i = 0; i < top_k; ++i) { - sum += topk_weights_local[i]; - } - if (sum > 0) { - for (int i = 0; i < top_k; ++i) { - topk_weights_local[i] /= sum; + // Barrier: ensure all warps have written smemScoreSigmoid/smemScoreBias + // before any warp reads them in the topk reduction below. + item.barrier(sycl::access::fence_space::local_space); + + if constexpr (MaxNumExperts > MaxNumExpertsUnit) { + constexpr int NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; + constexpr int NumInterTopK = NumExpertWarps * MaxNumTopExperts; + T* smemInterTopScores = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + IdxT* smemInterTopExperts = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + + if (warpIdx < NumExpertWarps) { + int32_t offset = warpIdx * WARP_SIZE * MaxNumTopGroups; + + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + int expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = (offset + expertIdx < numExperts) + ? 
smemScoreBias[offset + expertIdx] + : invalidScoreT; + } + reduce_topk::reduceTopK( + group, topScores, topExperts, expertScoreGroup, expertIdxGroup, + invalidScoreT, static_cast(topk)); + + if (laneIdx < MaxNumTopExperts) { + if (laneIdx < topk) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } else { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreT; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1; + } + } + } + item.barrier(sycl::access::fence_space::local_space); + if (warpIdx == 0) { + constexpr int NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; + T intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; + T invalidScoreT = neg_inf(); + + for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) { + int ii = i / WARP_SIZE; + if (i < NumInterTopK) { + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; + } else { + intermediateScore[ii] = invalidScoreT; + intermediateExpert[ii] = MaxNumExperts - 1; + } + } + + reduce_topk::reduceTopK( + group, topScores, topExperts, intermediateScore, intermediateExpert, + invalidScoreT, static_cast(topk)); + } + } else { + if (warpIdx == 0) { + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + int32_t expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = (expertIdx < numExperts) + ? 
smemScoreBias[expertIdx] + : invalidScoreT; + } + reduce_topk::reduceTopK( + group, topScores, topExperts, expertScoreGroup, expertIdxGroup, + invalidScoreT, static_cast(topk)); } - } } - if (sg_local_id == 0) { - int offset = tid * top_k; - for (int i = 0; i < top_k; ++i) { - topk_weights[offset + i] = - topk_weights_local[i] * routed_scaling_factor; - if (!(topk_ids_local[i] >= 0 && topk_ids_local[i] < experts)) { - // Ensure valid index - topk_ids[offset + i] = 0; - continue; + if (warpIdx == 0) { + int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + T scoreNormT = laneIdx < topk ? smemScoreSigmoid[expertIdx] : sycl_cast(0.F); + float scoreNorm = sycl_cast(scoreNormT); + float finalScore = static_cast(scoreNorm * routedScalingFactor); + float topk_sum = 1e-20f; + if (renormalize) { + topk_sum += sycl::reduce_over_group(group, scoreNorm,sycl::plus()); + finalScore /= topk_sum; + } + if (laneIdx < topk) { + topkValues[laneIdx] = finalScore; + topkIndices[laneIdx] = expertIdx; } - topk_ids[offset + i] = topk_ids_local[i]; - } } - } - float* topk_weights; - int* topk_ids; - const T* gating_output; - const T* e_score_correction_bias; - const double routed_scaling_factor; - const ScoringFunc scoring_mode; - const bool renormalize; - const int tokens; - const int experts; - const int top_k; - const int num_expert_group; - const int topk_group; -}; - -template -void launch_fused_grouped_topk( - sycl::queue& queue, - float* topk_weights, - int* topk_ids, - const T* gating_output, - const T* e_score_correction_bias, - const double routed_scaling_factor, - const ScoringFunc scoring_mode, - const bool renormalize, - const int tokens, - const int experts, - const int top_k, - const int num_expert_group, - const int topk_group) { - using Kernel = Fused_Grouped_Topk; - auto range = Kernel::get_nd_range(tokens, experts); - - queue.submit([&](sycl::handler& cgh) { - Kernel task( - topk_weights, - topk_ids, - gating_output, - e_score_correction_bias, - 
routed_scaling_factor, - scoring_mode, - renormalize, - tokens, - experts, - top_k, - num_expert_group, - topk_group); - cgh.parallel_for(range, task); - }); + } // end if constexpr (!UseGroups) } -template -void fused_grouped_topk( - float* topk_weights, - int* topk_ids, - const T* gating_output, - const T* e_score_correction_bias, - const double routed_scaling_factor, - const ScoringFunc scoring_mode, - const bool renormalize, - const int tokens, - const int experts, - const int top_k, - const int num_expert_group, - const int topk_group) { - auto& queue = vllm::xpu::vllmGetQueue(); - - TORCH_CHECK( - topk_group <= num_expert_group, - "topk_group must be less than or equal to num_expert_group"); - TORCH_CHECK( - experts % num_expert_group == 0, - "The number of experts (experts=", - experts, - ") must be divisible by num_expert_group (", - num_expert_group, - ")."); - - int max_expert_group = ((num_expert_group + 7) / 8) * 8; -#define CASE_TOPK(K) \ - case K: \ - launch_fused_grouped_topk( \ - queue, \ - topk_weights, \ - topk_ids, \ - gating_output, \ - e_score_correction_bias, \ - routed_scaling_factor, \ - scoring_mode, \ - renormalize, \ - tokens, \ - experts, \ - top_k, \ - num_expert_group, \ - topk_group); \ - break; - switch (max_expert_group) { - CASE_TOPK(8) - CASE_TOPK(16) - default: - TORCH_CHECK( - false, "error: not support num_expert_group=%d,\n", num_expert_group); - } -#undef CASE_TOPK +template +void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices, + BiasT const* bias, int64_t const num_tokens, + int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, + bool const renormalize, double const routed_scaling_factor, + bool enable_pdl = false, sycl::queue queue = sycl::queue()) { + int64_t experts_per_group = num_experts / n_group; + bool is_single_group = + (n_group == 1) && (topk_group == 1) && + (num_experts <= MaxSupportedExpertCount) && + (topk <= DefaultMaxNumTopExperts || topk == 
MaxSupportedTopExperts); + + #define LAUNCH_SMALL_KERNEL(MAX_EXPERTS, USE_GROUPS, MAX_TOP_EXPERTS, NUM_THREADS) \ + do { \ + size_t local_size = static_cast(NUM_THREADS); \ + size_t global_size = static_cast(num_tokens) * local_size; \ + queue.submit([&](sycl::handler& cgh) { \ + cgh.parallel_for>( \ + sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), \ + [=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { \ + grouped_topk_fused_small_expert_count_kernel( \ + scores, topk_values, topk_indices, bias, \ + num_tokens, n_group, topk_group, topk, num_experts, \ + experts_per_group, renormalize, routed_scaling_factor, item); \ + }); \ + }); \ + } while (0) + + if (is_single_group) { + if (num_experts == NumNemotronExperts && n_group == 1 && + topk == MaxSupportedTopExperts) { + LAUNCH_SMALL_KERNEL(NumNemotronExperts, false, + MaxSupportedTopExperts, + ((NumNemotronExperts + MaxNumExpertsUnit - 1) / + MaxNumExpertsUnit) * WARP_SIZE); + } else if (num_experts > NumKimiK2Experts && + num_experts <= MaxSupportedExpertCount) { + LAUNCH_SMALL_KERNEL(MaxSupportedExpertCount, false, + DefaultMaxNumTopExperts, + ((MaxSupportedExpertCount + MaxNumExpertsUnit - 1) / + MaxNumExpertsUnit) * WARP_SIZE); + } else if (num_experts > MaxNumExpertsUnit && + num_experts <= NumKimiK2Experts) { + LAUNCH_SMALL_KERNEL(NumKimiK2Experts, false, + DefaultMaxNumTopExperts, + ((NumKimiK2Experts + MaxNumExpertsUnit - 1) / + MaxNumExpertsUnit) * WARP_SIZE); + } else { + LAUNCH_SMALL_KERNEL(MaxNumExpertsUnit, false, + DefaultMaxNumTopExperts, + WARP_SIZE); + } + } else { + LAUNCH_SMALL_KERNEL(NumDeepseekExperts, true, + DefaultMaxNumTopExperts, + WARP_SIZE); + } + + #undef LAUNCH_SMALL_KERNEL + } -}; // namespace GroupedTopKImpl +#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \ + template void invokeNoAuxTc( \ + T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \ + int64_t const num_tokens, int64_t const num_experts, \ + int64_t const 
n_group, int64_t const topk_group, int64_t const topk, \ + bool const renormalize, double const routed_scaling_factor, \ + bool enable_pdl, sycl::queue queue); + +INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +} // end namespace moe } // namespace vllm -/** - * @brief Perform grouped topk after sigmoid/addbias on gating_output. - * @param gating_output The gating output tensor of shape [n_tokens, n_experts]. - * @param n_topk The number of top experts to select. - * @param n_topk_group The number of top experts to select in the group. - * @return A tuple of tensors (topk_weights, topk_indices). 
- */ std::tuple fused_grouped_topk( - const torch::Tensor& hidden_states, - const torch::Tensor& gating_output, - const int64_t n_topk, - const bool renormalize, - const int64_t n_expert_group, - const int64_t n_topk_group, - const c10::string_view scoring_func, - const double routed_scaling_factor, - const c10::optional& bias) { - auto shape = gating_output.sizes().vec(); - TORCH_CHECK( - hidden_states.sizes()[0] == gating_output.sizes()[0], - "Number of tokens mismatch") - TORCH_CHECK( - shape.size() == 2, - "gating_output must be 2D tensor, but got ", - shape.size(), - "D"); - if (bias.has_value()) { - auto shape_bias = bias->sizes().vec(); - TORCH_CHECK( - shape_bias[0] == shape[1], - "gating_output and bias must has same innermost dimension, but got ", - shape, - " and ", - shape_bias); - } - int n_tokens = shape[0]; - int n_experts = shape[1]; - - vllm::GroupedTopKImpl::ScoringFunc scoring_mode; - if (scoring_func == "sigmoid") { - scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::SIGMOID; - } else if (scoring_func == "softmax") { - scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::SOFTMAX; - } else { - scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::DEFAULT; - } - - auto topk_weights = - torch::empty({n_tokens, n_topk}, at::dtype(at::kFloat).device(at::kXPU)); - auto topk_indices = - torch::empty({n_tokens, n_topk}, at::dtype(at::kInt).device(at::kXPU)); - - if (gating_output.scalar_type() == at::kBFloat16) { - using scalar_t = sycl::ext::oneapi::bfloat16; - vllm::GroupedTopKImpl::fused_grouped_topk( - reinterpret_cast(topk_weights.data_ptr()), - reinterpret_cast(topk_indices.data_ptr()), - reinterpret_cast(gating_output.data_ptr()), - bias.has_value() ? 
reinterpret_cast(bias->data_ptr()) - : nullptr, - routed_scaling_factor, - scoring_mode, - renormalize, - n_tokens, - n_experts, - n_topk, - n_expert_group, - n_topk_group); - } else if (gating_output.scalar_type() == at::kHalf) { - using scalar_t = sycl::half; - vllm::GroupedTopKImpl::fused_grouped_topk( - reinterpret_cast(topk_weights.data_ptr()), - reinterpret_cast(topk_indices.data_ptr()), - reinterpret_cast(gating_output.data_ptr()), - bias.has_value() ? reinterpret_cast(bias->data_ptr()) - : nullptr, - routed_scaling_factor, - scoring_mode, - renormalize, - n_tokens, - n_experts, - n_topk, - n_expert_group, - n_topk_group); - } else { - using scalar_t = float; - vllm::GroupedTopKImpl::fused_grouped_topk( - reinterpret_cast(topk_weights.data_ptr()), - reinterpret_cast(topk_indices.data_ptr()), - reinterpret_cast(gating_output.data_ptr()), - bias.has_value() ? reinterpret_cast(bias->data_ptr()) - : nullptr, - routed_scaling_factor, - scoring_mode, - renormalize, - n_tokens, - n_experts, - n_topk, - n_expert_group, - n_topk_group); + torch::Tensor const& hidden_states, + torch::Tensor const& gating_output, + int64_t const n_topk, + bool const renormalize, + int64_t const n_expert_group, + int64_t const n_topk_group, + c10::string_view const scoring_func, + double const routed_scaling_factor, + c10::optional const& bias) { + auto data_type = gating_output.scalar_type(); + bool has_bias = bias.has_value() && bias->defined(); + auto bias_type = has_bias ? 
bias->scalar_type() : torch::kFloat32; + auto input_size = gating_output.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + int64_t n_group = n_expert_group; + int64_t topk_group = n_topk_group; + int64_t topk = n_topk; + + TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0], + "Number of tokens mismatch"); + TORCH_CHECK(input_size.size() == 2, "gating_output must be a 2D Tensor"); + TORCH_CHECK(n_group > 0, "n_group must be positive"); + TORCH_CHECK(topk > 0, "topk must be positive"); + TORCH_CHECK(topk_group > 0, "topk_group must be positive"); + TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group"); + TORCH_CHECK(num_experts % n_group == 0, + "num_experts should be divisible by n_group"); + TORCH_CHECK(n_group <= 32, + "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= topk_group * (num_experts / n_group), + "topk must be <= topk_group * (num_experts / n_group)"); + TORCH_CHECK(scoring_func == "sigmoid" || scoring_func == "softmax", + "Unsupported scoring_func: ", scoring_func); + auto const sf = (scoring_func == "sigmoid") + ? 
vllm::moe::SCORING_SIGMOID + : vllm::moe::SCORING_NONE; + + // Always output float32 for topk_values (eliminates Python-side conversion) + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kFloat32).device(gating_output.device())); + torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kInt32).device(gating_output.device())); + + auto device_idx = gating_output.device().index(); + auto stream = c10::xpu::getCurrentXPUStream(device_idx).queue(); + +#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \ + do { \ + switch (sf) { \ + case vllm::moe::SCORING_NONE: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(gating_output.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + (has_bias ? reinterpret_cast(bias->data_ptr()) : nullptr), num_tokens, \ + num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, false, stream); \ + break; \ + case vllm::moe::SCORING_SIGMOID: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(gating_output.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + (has_bias ? 
reinterpret_cast(bias->data_ptr()) : nullptr), num_tokens, \ + num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, false, stream); \ + break; \ + default: \ + throw std::invalid_argument("Unsupported scoring_func"); \ + break; \ + } \ + } while (0) + +#define LAUNCH_KERNEL(T, IdxT) \ + do{ \ + switch (bias_type) { \ + case torch::kFloat16: \ + LAUNCH_KERNEL_SF(T, sycl::half, IdxT); \ + break; \ + case torch::kFloat32: \ + LAUNCH_KERNEL_SF(T, float, IdxT); \ + break; \ + case torch::kBFloat16: \ + LAUNCH_KERNEL_SF(T, sycl::ext::oneapi::bfloat16, IdxT); \ + break; \ + default: \ + throw std::invalid_argument( \ + "Invalid bias dtype, only supports float16, float32, and " \ + "bfloat16"); \ + break; \ + } \ + } \ + while (0) + + + switch (data_type) { + case torch::kFloat16: + LAUNCH_KERNEL(sycl::half, int32_t); + break; + case torch::kFloat32: + LAUNCH_KERNEL(float, int32_t); + break; + case torch::kBFloat16: + LAUNCH_KERNEL(sycl::ext::oneapi::bfloat16, int32_t); + break; + default: + throw std::invalid_argument( + "Invalid dtype, only supports float16, float32, and bfloat16"); + break; } - return std::make_tuple(topk_weights, topk_indices); +#undef LAUNCH_KERNEL +#undef LAUNCH_KERNEL_SF + return {topk_values, topk_indices}; } \ No newline at end of file diff --git a/csrc/moe/grouped_topk_kernels.cpp b/csrc/moe/grouped_topk_kernels.cpp deleted file mode 100644 index fef60a4af..000000000 --- a/csrc/moe/grouped_topk_kernels.cpp +++ /dev/null @@ -1,652 +0,0 @@ -/* - * Adapted from - * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu - * Copyright (c) 2025, The vLLM team. - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include -#include -#include "../dispatch_utils.h" - -namespace vllm { -namespace moe { - -// Type trait: bfloat16 -> float for computation, everything else stays as-is -template -struct compute_type { using type = T; }; - -template <> -struct compute_type { using type = float; }; - -template -using compute_type_t = typename compute_type::type; - -constexpr unsigned FULL_WARP_MASK = 0xffffffff; -static constexpr int WARP_SIZE = 32; -static constexpr int NumNemotronExperts = 512; -static constexpr int NumKimiK2Experts = 384; -static constexpr int NumDeepseekExperts = 256; -static constexpr int MaxSupportedExpertCount = - std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); -static constexpr int MaxNumExpertsUnit = 128; -static constexpr int NumTopGroupScores = 2; -static constexpr int DefaultMaxNumTopExperts = 8; -static constexpr int MaxSupportedTopExperts = 22; -static constexpr int MaxNumTopGroups = 4; - -enum ScoringFunc : int { SCORING_NONE = 0, SCORING_SIGMOID = 1 }; - -template -class VllmGroupedTopKFusedKernel; - -template -class VllmGroupedTopKFusedSmallExpertCountKernel; - -template -inline T_OUT sycl_cast(T_IN val) { - return static_cast(val); -} - -template <> -inline float sycl_cast(sycl::half val) { - return static_cast(val); -} - -template <> -inline float sycl_cast(sycl::ext::oneapi::bfloat16 val) { - return static_cast(val); -} - -template -inline T neg_inf() { - return sycl_cast(-std::numeric_limits::infinity()); -} - -template -inline bool is_finite(const T val) { - 
return std::isfinite(sycl_cast(val)); -} - -inline float sigmoid_accurate(float x) { - return 1.f / (1.f + sycl::native::exp(-x)); // More efficient approximation Optimized point 1 -} - -template -inline T apply_sigmoid(T val) { - float f = sycl_cast(val); - return sycl_cast(sigmoid_accurate(f)); -} - -template -inline T apply_scoring(T val) { - if constexpr (SF == SCORING_NONE) { - return val; - } else if constexpr (SF == SCORING_SIGMOID) { - return apply_sigmoid(val); - } else { - static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID, - "Unsupported ScoringFunc in apply_scoring"); - return val; - } -} - -namespace reduce_topk { - -template -inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, - const T* in_vals, const IdxT* in_idxs, T min_val, - int topk) { - constexpr IdxT invalid_idx = std::numeric_limits::max(); - bool selected[N_IN] = {false}; - - for (int k = 0; k < topk; ++k) { - using CT = compute_type_t; - CT local_best_val = static_cast(min_val); - IdxT local_best_idx = invalid_idx; - int local_best_pos = -1; - - #pragma unroll - for (int i = 0; i < N_IN; ++i) { - if (selected[i]) { - continue; - } - T cand_val = in_vals[i]; - IdxT cand_idx = in_idxs[i]; - if ((cand_val > local_best_val) || - ((cand_val == local_best_val) && (cand_idx < local_best_idx))) { - local_best_val = cand_val; - local_best_idx = cand_idx; - local_best_pos = i; - } - } - - T warp_best_val = sycl::reduce_over_group( - subgroup, local_best_val, sycl::maximum()); - - IdxT warp_best_idx = invalid_idx; - if (local_best_pos != -1 && local_best_val == warp_best_val) { - warp_best_idx = local_best_idx; - } - warp_best_idx = sycl::reduce_over_group( - subgroup, warp_best_idx, sycl::minimum()); - - bool found = (warp_best_idx != invalid_idx); - if (found) { - int insert_pos = k; - while (insert_pos > 0 && out_val[insert_pos - 1] == warp_best_val && - out_idx[insert_pos - 1] > warp_best_idx) { - out_val[insert_pos] = out_val[insert_pos - 1]; - out_idx[insert_pos] = 
out_idx[insert_pos - 1]; - --insert_pos; - } - out_val[insert_pos] = warp_best_val; - out_idx[insert_pos] = warp_best_idx; - } else { - out_val[k] = min_val; - out_idx[k] = 0; - } - - if (found && local_best_pos != -1 && local_best_val == warp_best_val && - local_best_idx == warp_best_idx) { - selected[local_best_pos] = true; - } - } -} - -template -inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, - T val, IdxT idx, T min_val, int topk) { - T in_vals[1] = {val}; - IdxT in_idxs[1] = {idx}; - reduceTopK<1>(subgroup, out_val, out_idx, in_vals, in_idxs, min_val, - topk); -} - -} // namespace reduce_topk - -template -SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( - T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias, - int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, - int64_t const topk, int64_t const numExperts, - int64_t const numExpertsPerGroup, bool const renormalize, - double const routedScalingFactor, sycl::nd_item<1> item) { - - constexpr int NumWarps = MaxNumExperts / WARP_SIZE; - constexpr float invalidScoreFloat = -std::numeric_limits::infinity(); - - int threadIdx = item.get_local_id(0); - int blockIdx = item.get_group(0); - if constexpr (UseGroups){ - if (blockIdx >= numTokens) return; - } - int localSize = item.get_local_range(0); - bool has_bias = (routingBias != nullptr); - - int laneIdx = threadIdx % WARP_SIZE; - int warpIdx = threadIdx / WARP_SIZE; - - - topkValues += blockIdx * topk; - topkIndices += blockIdx * topk; - - if constexpr (UseGroups) { - auto subgroup = item.get_sub_group(); - T* scoresToken = scores + static_cast(blockIdx) * numExperts; - T selectedGroupScores[WARP_SIZE]; - int32_t selectedGroupIdx[WARP_SIZE]; - - T groupScore = neg_inf(); - if (laneIdx < numGroup) { - int32_t groupOffset = laneIdx * numExpertsPerGroup; - T largest = neg_inf(); - T secondLargest = neg_inf(); - - for (int32_t i = 0; i < numExpertsPerGroup; ++i) { - T value = 
apply_scoring(scoresToken[groupOffset + i]); - if (has_bias) { - value = value + sycl_cast(routingBias[groupOffset + i]); - } - if (value > largest) { - secondLargest = largest; - largest = value; - } else if (value > secondLargest) { - secondLargest = value; - } - } - groupScore = has_bias ? largest + secondLargest : largest; - } - - reduce_topk::reduceTopK( - subgroup, selectedGroupScores, selectedGroupIdx, - groupScore, laneIdx, neg_inf(), static_cast(topkGroup)); - - bool proceed = false; - if (topkGroup > 0) { - proceed = (selectedGroupScores[topkGroup - 1] != neg_inf()); - } - - if (!proceed) { - for (int i = laneIdx; i < topk; i += WARP_SIZE) { - topkIndices[i] = static_cast(i); - topkValues[i] = 1.0f / static_cast(topk); - } - return; - } - - constexpr int MaxExpertCandidatesPerLane = NumDeepseekExperts / WARP_SIZE; - T localCandidateScores[MaxExpertCandidatesPerLane]; - IdxT localCandidateIdx[MaxExpertCandidatesPerLane]; - T selectedExpertScores[DefaultMaxNumTopExperts]; - IdxT selectedExpertIdx[DefaultMaxNumTopExperts]; - - for (int i = 0; i < MaxExpertCandidatesPerLane; ++i) { - localCandidateScores[i] = neg_inf(); - localCandidateIdx[i] = 0; - } - - int32_t totalCandidates = topkGroup * numExpertsPerGroup; - for (int32_t candidate = laneIdx; candidate < totalCandidates; - candidate += WARP_SIZE) { - int32_t localSlot = candidate / WARP_SIZE; - int32_t selectedGroup = candidate / numExpertsPerGroup; - int32_t expertInGroup = candidate % numExpertsPerGroup; - int32_t gid = selectedGroupIdx[selectedGroup]; - int32_t idx = gid * numExpertsPerGroup + expertInGroup; - T candidateScore = neg_inf(); - - T input = scoresToken[idx]; - if (is_finite(input)) { - T score = apply_scoring(input); - candidateScore = score; - if (has_bias) { - candidateScore = candidateScore + sycl_cast(routingBias[idx]); - } - } - - localCandidateScores[localSlot] = candidateScore; - localCandidateIdx[localSlot] = static_cast(idx); - } - - reduce_topk::reduceTopK( - subgroup, 
selectedExpertScores, selectedExpertIdx, - localCandidateScores, localCandidateIdx, neg_inf(), static_cast(topk)); - - for (int i = 1; i < topk; ++i) { - T score = selectedExpertScores[i]; - IdxT idx = selectedExpertIdx[i]; - int j = i; - while (j > 0 && - ((selectedExpertScores[j - 1] < score) || - ((selectedExpertScores[j - 1] == score) && - (selectedExpertIdx[j - 1] > idx)))) { - selectedExpertScores[j] = selectedExpertScores[j - 1]; - selectedExpertIdx[j] = selectedExpertIdx[j - 1]; - --j; - } - selectedExpertScores[j] = score; - selectedExpertIdx[j] = idx; - } - - float laneUnbiased = 0.0f; - IdxT laneIdxOut = 0; - if (laneIdx < topk) { - laneIdxOut = selectedExpertIdx[laneIdx]; - T in = scoresToken[static_cast(laneIdxOut)]; - laneUnbiased = sycl_cast(apply_scoring(in)); - } - - float scale = static_cast(routedScalingFactor); - if (renormalize) { - float topkSum = 1e-20f; - topkSum += sycl::reduce_over_group( - subgroup, laneUnbiased,sycl::plus()); - scale /= topkSum; - } - - if (laneIdx < topk) { - topkIndices[laneIdx] = laneIdxOut; - topkValues[laneIdx] = laneUnbiased * scale; - } - return; - } else { - - float* smemScoreSigmoid = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); - float* smemScoreBias = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); - float topScores[MaxNumTopExperts] = {invalidScoreFloat}; - int32_t topExperts[MaxNumTopExperts] = {0}; - float expertScoreGroup[MaxNumTopGroups] = {invalidScoreFloat}; - int32_t expertIdxGroup[MaxNumTopGroups] = {0}; - auto group = item.get_sub_group(); - - for (int expert = threadIdx; expert < numExperts; expert += localSize) { - int64_t scoreIdx = int64_t{blockIdx} * int64_t{numExperts} + expert; - float score = sycl_cast(scores[scoreIdx]); - float scoreSigmoid = apply_scoring(score); - smemScoreSigmoid[expert] = scoreSigmoid; - smemScoreBias[expert] = has_bias - ? 
(scoreSigmoid + sycl_cast(routingBias[expert])) - : scoreSigmoid; - } - - if constexpr (MaxNumExperts > MaxNumExpertsUnit) { - constexpr int NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; - constexpr int NumInterTopK = NumExpertWarps * MaxNumTopExperts; - float* smemInterTopScores = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); - int32_t* smemInterTopExperts = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); - - if (warpIdx < NumExpertWarps) { - int offset = warpIdx * WARP_SIZE * MaxNumTopGroups; - - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - int expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = offset + expertIdx; - expertScoreGroup[ii] = (offset + expertIdx < numExperts) - ? smemScoreBias[offset + expertIdx] - : invalidScoreFloat; - } - reduce_topk::reduceTopK( - group, topScores, topExperts, expertScoreGroup, expertIdxGroup, - invalidScoreFloat, static_cast(topk)); - - if (laneIdx < MaxNumTopExperts) { - if (laneIdx < topk) { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; - } else { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreFloat; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1; - } - } - } - item.barrier(sycl::access::fence_space::local_space); - if (warpIdx == 0) { - constexpr int NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; - float intermediateScore[NumInterTopKPerThread]; - int32_t intermediateExpert[NumInterTopKPerThread]; - - for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) { - int ii = i / WARP_SIZE; - if (i < NumInterTopK) { - intermediateScore[ii] = smemInterTopScores[i]; - intermediateExpert[ii] = smemInterTopExperts[i]; - } else { - intermediateScore[ii] = invalidScoreFloat; - intermediateExpert[ii] = MaxNumExperts - 1; - } - } - - 
reduce_topk::reduceTopK( - group, topScores, topExperts, intermediateScore, intermediateExpert, - invalidScoreFloat, static_cast(topk)); - } - } else { - if (warpIdx == 0) { - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - int32_t expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] = (expertIdx < numExperts) - ? smemScoreBias[expertIdx] - : invalidScoreFloat; - } - reduce_topk::reduceTopK( - group, topScores, topExperts, expertScoreGroup, expertIdxGroup, - invalidScoreFloat, static_cast(topk)); - } - } - - if (warpIdx == 0) { - int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; - float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F; - float finalScore = static_cast(scoreNorm * routedScalingFactor); - float topk_sum = 1e-20f; - if (renormalize) { - topk_sum += sycl::reduce_over_group(group, scoreNorm,sycl::plus()); - finalScore /= topk_sum; - } - if (laneIdx < topk) { - topkIndices[laneIdx] = finalScore; - topkValues[laneIdx] = expertIdx; - } - } - } // end if constexpr (!UseGroups) -} - -template -void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices, - BiasT const* bias, int64_t const num_tokens, - int64_t const num_experts, int64_t const n_group, - int64_t const topk_group, int64_t const topk, - bool const renormalize, double const routed_scaling_factor, - bool enable_pdl = false, sycl::queue queue = sycl::queue()) { - int64_t experts_per_group = num_experts / n_group; - bool is_single_group = - (n_group == 1) && (topk_group == 1) && - (num_experts <= MaxSupportedExpertCount) && - (topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts); - - #define LAUNCH_SMALL_KERNEL(MAX_EXPERTS, USE_GROUPS, MAX_TOP_EXPERTS, NUM_THREADS) \ - do { \ - size_t local_size = static_cast(NUM_THREADS); \ - size_t global_size = static_cast(num_tokens) * local_size; \ - queue.submit([&](sycl::handler& cgh) { \ - cgh.parallel_for>( \ - 
sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), \ - [=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { \ - grouped_topk_fused_small_expert_count_kernel( \ - scores, topk_values, topk_indices, bias, \ - num_tokens, n_group, topk_group, topk, num_experts, \ - experts_per_group, renormalize, routed_scaling_factor, item); \ - }); \ - }); \ - } while (0) - - if (is_single_group) { - if (num_experts == NumNemotronExperts && n_group == 1 && - topk == MaxSupportedTopExperts) { - LAUNCH_SMALL_KERNEL(NumNemotronExperts, false, - MaxSupportedTopExperts, - ((NumNemotronExperts + MaxNumExpertsUnit - 1) / - MaxNumExpertsUnit) * WARP_SIZE); - } else if (num_experts > NumKimiK2Experts && - num_experts <= MaxSupportedExpertCount) { - LAUNCH_SMALL_KERNEL(MaxSupportedExpertCount, false, - DefaultMaxNumTopExperts, - ((MaxSupportedExpertCount + MaxNumExpertsUnit - 1) / - MaxNumExpertsUnit) * WARP_SIZE); - } else if (num_experts > MaxNumExpertsUnit && - num_experts <= NumKimiK2Experts) { - LAUNCH_SMALL_KERNEL(NumKimiK2Experts, false, - DefaultMaxNumTopExperts, - ((NumKimiK2Experts + MaxNumExpertsUnit - 1) / - MaxNumExpertsUnit) * WARP_SIZE); - } else { - LAUNCH_SMALL_KERNEL(MaxNumExpertsUnit, false, - DefaultMaxNumTopExperts, - WARP_SIZE); - } - } else { - LAUNCH_SMALL_KERNEL(NumDeepseekExperts, true, - DefaultMaxNumTopExperts, - WARP_SIZE); - } - - #undef LAUNCH_SMALL_KERNEL - -} - -#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \ - template void invokeNoAuxTc( \ - T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \ - int64_t const num_tokens, int64_t const num_experts, \ - int64_t const n_group, int64_t const topk_group, int64_t const topk, \ - bool const renormalize, double const routed_scaling_factor, \ - bool enable_pdl, sycl::queue queue); - -INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(float, 
sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); -} // end namespace moe -} // namespace vllm - -std::tuple grouped_topk_multi_group( - torch::Tensor const& hidden_states, - torch::Tensor const& gating_output, - int64_t const n_topk, - bool const renormalize, - int64_t const n_expert_group, - int64_t const n_topk_group, - c10::string_view const scoring_func, - double const routed_scaling_factor, - c10::optional const& bias) { - auto data_type = gating_output.scalar_type(); - bool has_bias = bias.has_value() && bias->defined(); - auto bias_type = has_bias ? 
bias->scalar_type() : torch::kFloat32; - auto input_size = gating_output.sizes(); - int64_t num_tokens = input_size[0]; - int64_t num_experts = input_size[1]; - int64_t n_group = n_expert_group; - int64_t topk_group = n_topk_group; - int64_t topk = n_topk; - - TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0], - "Number of tokens mismatch"); - TORCH_CHECK(input_size.size() == 2, "gating_output must be a 2D Tensor"); - TORCH_CHECK(n_group > 0, "n_group must be positive"); - TORCH_CHECK(topk > 0, "topk must be positive"); - TORCH_CHECK(topk_group > 0, "topk_group must be positive"); - TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group"); - TORCH_CHECK(num_experts % n_group == 0, - "num_experts should be divisible by n_group"); - TORCH_CHECK(n_group <= 32, - "n_group should be smaller than or equal to 32 for now"); - TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); - TORCH_CHECK(topk <= topk_group * (num_experts / n_group), - "topk must be <= topk_group * (num_experts / n_group)"); - TORCH_CHECK(scoring_func == "sigmoid" || scoring_func == "softmax", - "Unsupported scoring_func: ", scoring_func); - // Pre-apply softmax on host side then use SCORING_NONE in kernel, - // because softmax requires a full-row reduction that the kernel doesn't support. 
- torch::Tensor scores; - vllm::moe::ScoringFunc sf; - if (scoring_func == "sigmoid") { - sf = vllm::moe::SCORING_SIGMOID; - scores = gating_output; - } else if (scoring_func == "softmax") { - sf = vllm::moe::SCORING_NONE; - scores = torch::softmax(gating_output, /*dim=*/-1); - } - - // Always output float32 for topk_values (eliminates Python-side conversion) - torch::Tensor topk_values = torch::empty( - {num_tokens, topk}, torch::dtype(torch::kFloat32).device(gating_output.device())); - torch::Tensor topk_indices = torch::empty( - {num_tokens, topk}, torch::dtype(torch::kInt32).device(gating_output.device())); - - auto device_idx = gating_output.device().index(); - auto stream = c10::xpu::getCurrentXPUStream(device_idx).queue(); - -#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \ - do { \ - switch (sf) { \ - case vllm::moe::SCORING_NONE: \ - vllm::moe::invokeNoAuxTc( \ - reinterpret_cast(scores.mutable_data_ptr()), \ - reinterpret_cast(topk_values.mutable_data_ptr()), \ - reinterpret_cast(topk_indices.mutable_data_ptr()), \ - (has_bias ? reinterpret_cast(bias->data_ptr()) : nullptr), num_tokens, \ - num_experts, n_group, topk_group, topk, renormalize, \ - routed_scaling_factor, false, stream); \ - break; \ - case vllm::moe::SCORING_SIGMOID: \ - vllm::moe::invokeNoAuxTc( \ - reinterpret_cast(scores.mutable_data_ptr()), \ - reinterpret_cast(topk_values.mutable_data_ptr()), \ - reinterpret_cast(topk_indices.mutable_data_ptr()), \ - (has_bias ? 
reinterpret_cast(bias->data_ptr()) : nullptr), num_tokens, \ - num_experts, n_group, topk_group, topk, renormalize, \ - routed_scaling_factor, false, stream); \ - break; \ - default: \ - throw std::invalid_argument("Unsupported scoring_func"); \ - break; \ - } \ - } while (0) - -#define LAUNCH_KERNEL(T, IdxT) \ - do{ \ - switch (bias_type) { \ - case torch::kFloat16: \ - LAUNCH_KERNEL_SF(T, sycl::half, IdxT); \ - break; \ - case torch::kFloat32: \ - LAUNCH_KERNEL_SF(T, float, IdxT); \ - break; \ - case torch::kBFloat16: \ - LAUNCH_KERNEL_SF(T, sycl::ext::oneapi::bfloat16, IdxT); \ - break; \ - default: \ - throw std::invalid_argument( \ - "Invalid bias dtype, only supports float16, float32, and " \ - "bfloat16"); \ - break; \ - } \ - } \ - while (0) - - - switch (data_type) { - case torch::kFloat16: - LAUNCH_KERNEL(sycl::half, int32_t); - break; - case torch::kFloat32: - LAUNCH_KERNEL(float, int32_t); - break; - case torch::kBFloat16: - LAUNCH_KERNEL(sycl::ext::oneapi::bfloat16, int32_t); - break; - default: - throw std::invalid_argument( - "Invalid dtype, only supports float16, float32, and bfloat16"); - break; - } -#undef LAUNCH_KERNEL -#undef LAUNCH_KERNEL_SF - return {topk_values, topk_indices}; -} \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 88f43762e..09cefdb53 100755 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -56,18 +56,6 @@ std::tuple fused_grouped_topk( const double routed_scaling_factor, const c10::optional& bias); -std::tuple grouped_topk_multi_group( - const torch::Tensor& hidden_states, - const torch::Tensor& gating_output, - const int64_t n_topk, - const bool renormalize, - const int64_t n_expert_group, - const int64_t n_topk_group, - const c10::string_view scoring_func, - const double routed_scaling_factor, - const c10::optional& bias); - - void topk_softmax( torch::Tensor& topk_weights, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 0f44d34ed..76145a877 100644 --- 
a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -57,21 +57,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Fused Grouped TopK m.def( - "fused_grouped_topk(Tensor hidden_states, Tensor gating_output, int " - "n_topk, " - "bool renormalize, int n_expert_group, int n_topk_group, str " - "scoring_func, float routed_scaling_factor, Tensor? bias=None) -> " - "(Tensor, Tensor)"); + "fused_grouped_topk(" + " Tensor hidden_states," + " Tensor gating_output," + " int n_topk," + " bool renormalize," + " int n_expert_group," + " int n_topk_group," + " str scoring_func," + " float routed_scaling_factor," + " Tensor? bias=None" + ") -> (Tensor, Tensor)"); m.impl("fused_grouped_topk", torch::kXPU, &fused_grouped_topk); - // Fused Grouped TopK (multi-group optimized path) - m.def( - "grouped_topk_multi_group(Tensor hidden_states, Tensor gating_output, int " - "n_topk, " - "bool renormalize, int n_expert_group, int n_topk_group, str " - "scoring_func, float routed_scaling_factor, Tensor? bias=None) -> " - "(Tensor, Tensor)"); - m.impl("grouped_topk_multi_group", torch::kXPU, &grouped_topk_multi_group); + // Grouped TopK Multi Group (from grouped_topk_kernels.cpp) + // Apply topk softmax to the gating outputs. m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! 
" diff --git a/tests/ops/grouped_topk_op.py b/tests/ops/grouped_topk_op.py index 6c31e1a4f..c6c91231e 100644 --- a/tests/ops/grouped_topk_op.py +++ b/tests/ops/grouped_topk_op.py @@ -113,25 +113,17 @@ def fused_grouped_topk_sycl( routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - return ops.fused_grouped_topk(hidden_states, gating_output, topk, + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + return ops.fused_grouped_topk(hidden_states, scores, topk, renormalize, num_expert_group, topk_group, scoring_func, routed_scaling_factor, e_score_correction_bias) -def grouped_topk_multi_group( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int, - topk_group: int, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - return ops.grouped_topk_multi_group(hidden_states, gating_output, topk, - renormalize, num_expert_group, - topk_group, scoring_func, - routed_scaling_factor, - e_score_correction_bias) diff --git a/tests/register_ops.py b/tests/register_ops.py index b6765efd1..0508ed06d 100644 --- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -473,24 +473,6 @@ def fused_grouped_topk( routed_scaling_factor, e_score_correction_bias) -def grouped_topk_multi_group( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int, - topk_group: int, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, -): - return 
torch.ops._moe_C.grouped_topk_multi_group(hidden_states, gating_output, - topk, renormalize, - num_expert_group, topk_group, - scoring_func, - routed_scaling_factor, - e_score_correction_bias) - def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indices: torch.Tensor, diff --git a/tests/test_grouped_topk.py b/tests/test_grouped_topk.py index 7eb904ff4..83ad24e22 100644 --- a/tests/test_grouped_topk.py +++ b/tests/test_grouped_topk.py @@ -3,7 +3,7 @@ import torch from tests.ops.grouped_topk_op import (fused_grouped_topk, - fused_grouped_topk_sycl, grouped_topk,grouped_topk_multi_group) + fused_grouped_topk_sycl, grouped_topk) from tests.utils import seed_everything #override pytest parameters when enable mini pytest @@ -18,14 +18,14 @@ } -@pytest.mark.parametrize("n_token", [1, 33, 64, 50000,100000]) +@pytest.mark.parametrize("n_token", [64, 50000,100000]) @pytest.mark.parametrize("n_hidden", [1024, 2048]) -@pytest.mark.parametrize("n_expert", [256]) +@pytest.mark.parametrize("n_expert", [128, 256]) @pytest.mark.parametrize("topk", [8]) -@pytest.mark.parametrize("renormalize", [False,True]) +@pytest.mark.parametrize("renormalize", [True, False]) @pytest.mark.parametrize("num_expert_group", [8]) @pytest.mark.parametrize("topk_group", [4]) -@pytest.mark.parametrize("scoring_func", ["sigmoid",'softmax']) +@pytest.mark.parametrize("scoring_func", ["sigmoid","softmax"]) @pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) @@ -36,10 +36,10 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, seed_everything(0) hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="xpu") gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="xpu") + e_score_correction_bias = torch.randn((n_expert, ), dtype=dtype, device="xpu") - baseline_topk_weights, baseline_topk_ids = grouped_topk( hidden_states=hidden_states, 
gating_output=gating_output, @@ -50,7 +50,6 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias) - test_topk_weights, test_topk_ids = fused_grouped_topk( hidden_states=hidden_states, gating_output=gating_output, @@ -72,40 +71,22 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias) - test_topk_weights_multi_group, test_topk_ids_multi_group = grouped_topk_multi_group( - hidden_states=hidden_states, - gating_output=gating_output, - topk=topk, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias) if renormalize: - # torch.testing.assert_close(baseline_topk_weights, - # test_topk_weights, - # atol=2e-2, - # rtol=0) - # torch.testing.assert_close(baseline_topk_weights, - # test_topk_weights_sycl, - # atol=2e-2, - # rtol=0) torch.testing.assert_close(baseline_topk_weights, - test_topk_weights_multi_group, + test_topk_weights, + atol=2e-2, + rtol=0) + torch.testing.assert_close(baseline_topk_weights, + test_topk_weights_sycl, atol=2e-2, rtol=0) - # torch.testing.assert_close(baseline_topk_ids, - # test_topk_ids, - # atol=0, - # rtol=0) - # torch.testing.assert_close(baseline_topk_ids, - # test_topk_ids_sycl, - # atol=0, - # rtol=0) torch.testing.assert_close(baseline_topk_ids, - test_topk_ids_multi_group, - atol=0, - rtol=0) \ No newline at end of file + test_topk_ids, + atol=0, + rtol=0) + torch.testing.assert_close(baseline_topk_ids, + test_topk_ids_sycl, + atol=0, + rtol=0) From 1fb5daa9b045168361331cbb5a719cf9ced2cbff Mon Sep 17 00:00:00 2001 From: xiaolong-intel Date: Fri, 3 Apr 2026 16:45:38 +0800 Subject: [PATCH 
03/15] Update fused_grouped_topk.cpp Signed-off-by: xiaolong-intel Signed-off-by: xiaolong Signed-off-by: root --- csrc/moe/fused_grouped_topk.cpp | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index 54705c41d..0b2613806 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -1,22 +1,3 @@ -/* - * Adapted from - * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu - * Copyright (c) 2025, The vLLM team. - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ #include #include #include @@ -645,4 +626,4 @@ std::tuple fused_grouped_topk( #undef LAUNCH_KERNEL #undef LAUNCH_KERNEL_SF return {topk_values, topk_indices}; -} \ No newline at end of file +} From 38b33d13c000e82d7730c8fde05c030b8961dba1 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 3 Apr 2026 09:15:03 +0000 Subject: [PATCH 04/15] import vllm_xpu_kernels._moe_C in fused_moe_interface.py Signed-off-by: root Signed-off-by: xiaolong Signed-off-by: root --- vllm_xpu_kernels/fused_moe_interface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 93735cb73..434529e4f 100755 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -6,6 +6,7 @@ try: from . import _C # noqa: F401 from . import _xpu_C # noqa: F401 + from . import _moe_C # noqa: F401 FUSEDMOE_UNAVAILABLE_REASON = None FUSEDMOE_AVAILABLE = True except ImportError as e: From 4df072b265c320a08d58fec19735d33ffc4a97f9 Mon Sep 17 00:00:00 2001 From: xiaolong-intel Date: Tue, 7 Apr 2026 10:09:14 +0800 Subject: [PATCH 05/15] Update torch_bindings.cpp Signed-off-by: xiaolong-intel Signed-off-by: xiaolong Signed-off-by: root --- csrc/moe/torch_bindings.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 76145a877..57015ca13 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -70,8 +70,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ") -> (Tensor, Tensor)"); m.impl("fused_grouped_topk", torch::kXPU, &fused_grouped_topk); - // Grouped TopK Multi Group (from grouped_topk_kernels.cpp) - // Apply topk softmax to the gating outputs. m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! 
" From 5b20f0c891b24f2eccafb2e0081157c6f054c828 Mon Sep 17 00:00:00 2001 From: xiaolong-intel Date: Tue, 7 Apr 2026 11:03:08 +0800 Subject: [PATCH 06/15] Update grouped_topk_op.py Signed-off-by: xiaolong-intel Signed-off-by: xiaolong Signed-off-by: root --- tests/ops/grouped_topk_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ops/grouped_topk_op.py b/tests/ops/grouped_topk_op.py index c6c91231e..b3b391860 100644 --- a/tests/ops/grouped_topk_op.py +++ b/tests/ops/grouped_topk_op.py @@ -21,7 +21,6 @@ def grouped_topk( ) -> tuple[torch.Tensor, torch.Tensor]: assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" - # Move to CPU to avoid XPU OOM on intermediate tensors if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) elif scoring_func == "sigmoid": From fb0b037bc13e3c30a11630db1b4e7b6647c75fec Mon Sep 17 00:00:00 2001 From: root Date: Wed, 8 Apr 2026 09:04:44 +0000 Subject: [PATCH 07/15] Modify to make the code more standardized Signed-off-by: root Signed-off-by: xiaolong Signed-off-by: root --- benchmark/benchmark_grouped_topk.py | 12 +++--- csrc/moe/fused_grouped_topk.cpp | 56 +++++++++---------------- csrc/utils.h | 1 + vllm_xpu_kernels/fused_moe_interface.py | 1 - 4 files changed, 26 insertions(+), 44 deletions(-) diff --git a/benchmark/benchmark_grouped_topk.py b/benchmark/benchmark_grouped_topk.py index 28bb2b034..253c99bcf 100644 --- a/benchmark/benchmark_grouped_topk.py +++ b/benchmark/benchmark_grouped_topk.py @@ -75,10 +75,10 @@ def grouped_topk_compile( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) -n_token_range = [1, 64, 256] -n_expert_range = [16, 64, 128] -topk_range = [2, 4] -topk_group_range = [4, 8] +n_token_range = [50000] +n_expert_range = [256] +topk_range = [8] +topk_group_range = [4] scoring_func_range = ["sigmoid", "softmax"] dtype_range = [torch.float16, torch.bfloat16, torch.float32] configs = list( @@ -121,7 +121,7 @@ def benchmark( ): n_hidden = 1024 
routed_scaling_factor = 1.0 - num_expert_group = 8 + num_expert_group = 4 renormalize = True hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, @@ -209,4 +209,4 @@ def benchmark( args = parser.parse_args() benchmark = get_benchmark() - benchmark.run(print_data=True, save_path=args.save_path) + benchmark.run(print_data=True, save_path=args.save_path) \ No newline at end of file diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index 0b2613806..e069c9b6d 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -1,24 +1,13 @@ #include -#include -#include #include #include #include +#include #include "../dispatch_utils.h" - +#include "../utils.h" namespace vllm { namespace moe { -// Type trait: bfloat16 -> float for computation, everything else stays as-is -template -struct compute_type { using type = T; }; - -template <> -struct compute_type { using type = float; }; - -template -using compute_type_t = typename compute_type::type; - constexpr unsigned FULL_WARP_MASK = 0xffffffff; static constexpr int WARP_SIZE = 32; static constexpr int NumNemotronExperts = 512; @@ -47,35 +36,28 @@ inline T_OUT sycl_cast(T_IN val) { return static_cast(val); } - -template <> -inline float sycl_cast(sycl::half val) { - return static_cast(val); -} - -template <> -inline float sycl_cast(sycl::ext::oneapi::bfloat16 val) { - return static_cast(val); -} - template inline T neg_inf() { - return sycl_cast(-std::numeric_limits::infinity()); + T out; + xpu::from_float(out, -std::numeric_limits::infinity()); + return out; } template inline bool is_finite(const T val) { - return std::isfinite(sycl_cast(val)); + return std::isfinite(xpu::to_float(val)); } + inline float sigmoid_accurate(float x) { - return 1.f / (1.f + sycl::native::exp(-x)); // More efficient approximation Optimized point 1 + return 1.f / (1.f + sycl::native::exp(-x)); } template inline T apply_sigmoid(T val) { - float f = sycl_cast(val); - return 
sycl_cast(sigmoid_accurate(f)); - + float f = xpu::to_float(val); + T out; + xpu::from_float(out, sigmoid_accurate(f)); + return out; } template @@ -97,7 +79,7 @@ inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, bool selected[N_IN] = {false}; for (int k = 0; k < topk; ++k) { - using CT = compute_type_t; + using CT = xpu::acc_type; CT local_best_val = static_cast(min_val); IdxT local_best_idx = invalid_idx; int local_best_pos = -1; @@ -292,7 +274,7 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( if (laneIdx < topk) { laneIdxOut = selectedExpertIdx[laneIdx]; T in = scoresToken[static_cast(laneIdxOut)]; - laneUnbiased = sycl_cast(apply_scoring(in)); + laneUnbiased = xpu::to_float(apply_scoring(in)); } float scale = static_cast(routedScalingFactor); @@ -329,8 +311,6 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( : scoreSigmoid; } - // Barrier: ensure all warps have written smemScoreSigmoid/smemScoreBias - // before any warp reads them in the topk reduction below. item.barrier(sycl::access::fence_space::local_space); if constexpr (MaxNumExperts > MaxNumExpertsUnit) { @@ -402,8 +382,10 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( if (warpIdx == 0) { int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; - T scoreNormT = laneIdx < topk ? smemScoreSigmoid[expertIdx] : sycl_cast(0.F); - float scoreNorm = sycl_cast(scoreNormT); + T temp; + xpu::from_float(temp, 0.F); + T scoreNormT = laneIdx < topk ? smemScoreSigmoid[expertIdx] : temp; + float scoreNorm = xpu::to_float(scoreNormT); float finalScore = static_cast(scoreNorm * routedScalingFactor); float topk_sum = 1e-20f; if (renormalize) { @@ -550,7 +532,7 @@ std::tuple fused_grouped_topk( ? 
vllm::moe::SCORING_SIGMOID : vllm::moe::SCORING_NONE; - // Always output float32 for topk_values (eliminates Python-side conversion) + torch::Tensor topk_values = torch::empty( {num_tokens, topk}, torch::dtype(torch::kFloat32).device(gating_output.device())); torch::Tensor topk_indices = torch::empty( diff --git a/csrc/utils.h b/csrc/utils.h index 0b24dfba4..6f217b359 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -107,6 +107,7 @@ struct SyclTypeTrait { using Type = sycl::ext::oneapi::bfloat16; }; + template struct AccumulateType { private: diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py index 434529e4f..93735cb73 100755 --- a/vllm_xpu_kernels/fused_moe_interface.py +++ b/vllm_xpu_kernels/fused_moe_interface.py @@ -6,7 +6,6 @@ try: from . import _C # noqa: F401 from . import _xpu_C # noqa: F401 - from . import _moe_C # noqa: F401 FUSEDMOE_UNAVAILABLE_REASON = None FUSEDMOE_AVAILABLE = True except ImportError as e: From 81ef579033aeff3a9b2757cf2ed730139d4d4289 Mon Sep 17 00:00:00 2001 From: xiaolong-intel Date: Thu, 9 Apr 2026 16:37:01 +0800 Subject: [PATCH 08/15] Update utils.h Signed-off-by: xiaolong-intel Signed-off-by: xiaolong Signed-off-by: root --- csrc/utils.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/utils.h b/csrc/utils.h index 6f217b359..b9efb1a79 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -114,7 +114,9 @@ struct AccumulateType { static constexpr bool is_narrow_float = std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v; + std::is_same_v || + std::is_same_v || + std::is_same_v;; static constexpr bool is_integer = std::is_same_v || std::is_same_v || From 1f6c09194c509fb08baaddfcab9490d81a37b01c Mon Sep 17 00:00:00 2001 From: root Date: Fri, 10 Apr 2026 05:35:32 +0000 Subject: [PATCH 09/15] Removed unnecessary modifications and made changes to the precision conversion of renormalization Signed-off-by: root Signed-off-by: xiaolong Signed-off-by: root --- 
benchmark/benchmark_grouped_topk.py | 10 +++--- csrc/moe/fused_grouped_topk.cpp | 48 ++++++++++++++++++++--------- csrc/utils.h | 4 +-- tests/test_grouped_topk.py | 15 ++++----- 4 files changed, 47 insertions(+), 30 deletions(-) diff --git a/benchmark/benchmark_grouped_topk.py b/benchmark/benchmark_grouped_topk.py index 253c99bcf..ceb6530cf 100644 --- a/benchmark/benchmark_grouped_topk.py +++ b/benchmark/benchmark_grouped_topk.py @@ -75,10 +75,10 @@ def grouped_topk_compile( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) -n_token_range = [50000] -n_expert_range = [256] -topk_range = [8] -topk_group_range = [4] +n_token_range = [1, 64, 256] +n_expert_range = [16, 64, 128] +topk_range = [2, 4] +topk_group_range = [4, 8] scoring_func_range = ["sigmoid", "softmax"] dtype_range = [torch.float16, torch.bfloat16, torch.float32] configs = list( @@ -121,7 +121,7 @@ def benchmark( ): n_hidden = 1024 routed_scaling_factor = 1.0 - num_expert_group = 4 + num_expert_group = 8 renormalize = True hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index e069c9b6d..1b9604555 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -20,6 +20,7 @@ static constexpr int NumTopGroupScores = 2; static constexpr int DefaultMaxNumTopExperts = 8; static constexpr int MaxSupportedTopExperts = 22; static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxReduceTopK = 32; enum ScoringFunc : int { SCORING_NONE = 0, SCORING_SIGMOID = 1 }; @@ -79,8 +80,7 @@ inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, bool selected[N_IN] = {false}; for (int k = 0; k < topk; ++k) { - using CT = xpu::acc_type; - CT local_best_val = static_cast(min_val); + T local_best_val = min_val; IdxT local_best_idx = invalid_idx; int local_best_pos = -1; @@ -98,11 +98,13 @@ inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, 
local_best_pos = i; } } - - T warp_best_val = sycl::reduce_over_group( - subgroup, local_best_val, sycl::maximum()); - + float local_best_val_tmp = xpu::to_float(local_best_val); + float warp_best_val_tmp = sycl::reduce_over_group( + subgroup, local_best_val_tmp, sycl::maximum()); + + T warp_best_val = static_cast(warp_best_val_tmp); IdxT warp_best_idx = invalid_idx; + if (local_best_pos != -1 && local_best_val == warp_best_val) { warp_best_idx = local_best_idx; } @@ -112,11 +114,20 @@ inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, bool found = (warp_best_idx != invalid_idx); if (found) { int insert_pos = k; - while (insert_pos > 0 && out_val[insert_pos - 1] == warp_best_val && - out_idx[insert_pos - 1] > warp_best_idx) { - out_val[insert_pos] = out_val[insert_pos - 1]; - out_idx[insert_pos] = out_idx[insert_pos - 1]; - --insert_pos; + bool still_shifting = true; + #pragma unroll + for (int shift = 0; shift < MaxReduceTopK - 1; ++shift) { + int prev_pos = k - shift - 1; + bool active = shift < k; + bool should_shift = active && still_shifting && + out_val[prev_pos] == warp_best_val && + out_idx[prev_pos] > warp_best_idx; + if (should_shift) { + out_val[prev_pos + 1] = out_val[prev_pos]; + out_idx[prev_pos + 1] = out_idx[prev_pos]; + insert_pos = prev_pos; + } + still_shifting = still_shifting && should_shift; } out_val[insert_pos] = warp_best_val; out_idx[insert_pos] = warp_best_idx; @@ -279,10 +290,17 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( float scale = static_cast(routedScalingFactor); if (renormalize) { - float topkSum = 1e-20f; - topkSum += sycl::reduce_over_group( - subgroup, laneUnbiased,sycl::plus()); - scale /= topkSum; + // Match baseline precision: sum and divide in T precision + T laneScoreT = static_cast(laneUnbiased); + T topkSumT = static_cast(0); + for (int i = 0; i < static_cast(topk); ++i) { + T val = sycl::select_from_group(subgroup, laneScoreT, i); + topkSumT = topkSumT + val; + } + 
if (laneIdx < topk) { + T normalizedT = laneScoreT / topkSumT; + laneUnbiased = xpu::to_float(normalizedT); + } } if (laneIdx < topk) { diff --git a/csrc/utils.h b/csrc/utils.h index b9efb1a79..6f217b359 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -114,9 +114,7 @@ struct AccumulateType { static constexpr bool is_narrow_float = std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v;; + std::is_same_v; static constexpr bool is_integer = std::is_same_v || std::is_same_v || diff --git a/tests/test_grouped_topk.py b/tests/test_grouped_topk.py index 83ad24e22..82968dcb9 100644 --- a/tests/test_grouped_topk.py +++ b/tests/test_grouped_topk.py @@ -18,14 +18,14 @@ } -@pytest.mark.parametrize("n_token", [64, 50000,100000]) +@pytest.mark.parametrize("n_token", [1, 33, 64]) @pytest.mark.parametrize("n_hidden", [1024, 2048]) -@pytest.mark.parametrize("n_expert", [128, 256]) -@pytest.mark.parametrize("topk", [8]) +@pytest.mark.parametrize("n_expert", [16]) +@pytest.mark.parametrize("topk", [2]) @pytest.mark.parametrize("renormalize", [True, False]) @pytest.mark.parametrize("num_expert_group", [8]) -@pytest.mark.parametrize("topk_group", [4]) -@pytest.mark.parametrize("scoring_func", ["sigmoid","softmax"]) +@pytest.mark.parametrize("topk_group", [2]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) @pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) @@ -36,10 +36,10 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, seed_everything(0) hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="xpu") gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="xpu") - e_score_correction_bias = torch.randn((n_expert, ), dtype=dtype, device="xpu") + baseline_topk_weights, baseline_topk_ids = grouped_topk( hidden_states=hidden_states, gating_output=gating_output, @@ 
-50,6 +50,7 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias) + test_topk_weights, test_topk_ids = fused_grouped_topk( hidden_states=hidden_states, gating_output=gating_output, @@ -89,4 +90,4 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, torch.testing.assert_close(baseline_topk_ids, test_topk_ids_sycl, atol=0, - rtol=0) + rtol=0) \ No newline at end of file From be52c90fcb2b3d25a310fc06934c33361c812c55 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 10 Apr 2026 06:01:08 +0000 Subject: [PATCH 10/15] pre-commit check Signed-off-by: root Signed-off-by: Signed-off-by: xiaolong Signed-off-by: root --- csrc/moe/fused_grouped_topk.cpp | 1059 ++++++++++++++++++------------- csrc/moe/moe_ops.h | 1 - csrc/moe/torch_bindings.cpp | 1 - csrc/utils.h | 1 - tests/ops/grouped_topk_op.py | 71 ++- 5 files changed, 644 insertions(+), 489 deletions(-) mode change 100755 => 100644 csrc/moe/moe_ops.h diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index 1b9604555..ea619c771 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -27,291 +27,330 @@ enum ScoringFunc : int { SCORING_NONE = 0, SCORING_SIGMOID = 1 }; template class VllmGroupedTopKFusedKernel; -template +template < + typename T, + typename BiasT, + typename IdxT, + ScoringFunc SF, + int MaxNumExperts, + bool UseGroups, + int MaxNumTopExperts = DefaultMaxNumTopExperts> class VllmGroupedTopKFusedSmallExpertCountKernel; template inline T_OUT sycl_cast(T_IN val) { - return static_cast(val); + return static_cast(val); } template inline T neg_inf() { - T out; - xpu::from_float(out, -std::numeric_limits::infinity()); - return out; + T out; + xpu::from_float(out, -std::numeric_limits::infinity()); + return out; } template inline bool is_finite(const T val) { - return 
std::isfinite(xpu::to_float(val)); + return std::isfinite(xpu::to_float(val)); } inline float sigmoid_accurate(float x) { - return 1.f / (1.f + sycl::native::exp(-x)); + return 1.f / (1.f + sycl::native::exp(-x)); } template inline T apply_sigmoid(T val) { - float f = xpu::to_float(val); - T out; - xpu::from_float(out, sigmoid_accurate(f)); - return out; + float f = xpu::to_float(val); + T out; + xpu::from_float(out, sigmoid_accurate(f)); + return out; } template inline T apply_scoring(T val) { - if constexpr (SF == SCORING_NONE) { - return val; - } else if constexpr (SF == SCORING_SIGMOID) { - return apply_sigmoid(val); - } + if constexpr (SF == SCORING_NONE) { + return val; + } else if constexpr (SF == SCORING_SIGMOID) { + return apply_sigmoid(val); + } } namespace reduce_topk { template -inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, - const T* in_vals, const IdxT* in_idxs, T min_val, - int topk) { - constexpr IdxT invalid_idx = std::numeric_limits::max(); - bool selected[N_IN] = {false}; - - for (int k = 0; k < topk; ++k) { - T local_best_val = min_val; - IdxT local_best_idx = invalid_idx; - int local_best_pos = -1; - - #pragma unroll - for (int i = 0; i < N_IN; ++i) { - if (selected[i]) { - continue; - } - T cand_val = in_vals[i]; - IdxT cand_idx = in_idxs[i]; - if ((cand_val > local_best_val) || - ((cand_val == local_best_val) && (cand_idx < local_best_idx))) { - local_best_val = cand_val; - local_best_idx = cand_idx; - local_best_pos = i; - } - } - float local_best_val_tmp = xpu::to_float(local_best_val); - float warp_best_val_tmp = sycl::reduce_over_group( - subgroup, local_best_val_tmp, sycl::maximum()); - - T warp_best_val = static_cast(warp_best_val_tmp); - IdxT warp_best_idx = invalid_idx; - - if (local_best_pos != -1 && local_best_val == warp_best_val) { - warp_best_idx = local_best_idx; - } - warp_best_idx = sycl::reduce_over_group( - subgroup, warp_best_idx, sycl::minimum()); - - bool found = (warp_best_idx != 
invalid_idx); - if (found) { - int insert_pos = k; - bool still_shifting = true; - #pragma unroll - for (int shift = 0; shift < MaxReduceTopK - 1; ++shift) { - int prev_pos = k - shift - 1; - bool active = shift < k; - bool should_shift = active && still_shifting && - out_val[prev_pos] == warp_best_val && - out_idx[prev_pos] > warp_best_idx; - if (should_shift) { - out_val[prev_pos + 1] = out_val[prev_pos]; - out_idx[prev_pos + 1] = out_idx[prev_pos]; - insert_pos = prev_pos; - } - still_shifting = still_shifting && should_shift; - } - out_val[insert_pos] = warp_best_val; - out_idx[insert_pos] = warp_best_idx; - } else { - out_val[k] = min_val; - out_idx[k] = 0; - } +inline void reduceTopK( + sycl::sub_group subgroup, + T* out_val, + IdxT* out_idx, + const T* in_vals, + const IdxT* in_idxs, + T min_val, + int topk) { + constexpr IdxT invalid_idx = std::numeric_limits::max(); + bool selected[N_IN] = {false}; + + for (int k = 0; k < topk; ++k) { + T local_best_val = min_val; + IdxT local_best_idx = invalid_idx; + int local_best_pos = -1; + +#pragma unroll + for (int i = 0; i < N_IN; ++i) { + if (selected[i]) { + continue; + } + T cand_val = in_vals[i]; + IdxT cand_idx = in_idxs[i]; + if ((cand_val > local_best_val) || + ((cand_val == local_best_val) && (cand_idx < local_best_idx))) { + local_best_val = cand_val; + local_best_idx = cand_idx; + local_best_pos = i; + } + } + float local_best_val_tmp = xpu::to_float(local_best_val); + float warp_best_val_tmp = sycl::reduce_over_group( + subgroup, local_best_val_tmp, sycl::maximum()); + + T warp_best_val = static_cast(warp_best_val_tmp); + IdxT warp_best_idx = invalid_idx; - if (found && local_best_pos != -1 && local_best_val == warp_best_val && - local_best_idx == warp_best_idx) { - selected[local_best_pos] = true; + if (local_best_pos != -1 && local_best_val == warp_best_val) { + warp_best_idx = local_best_idx; + } + warp_best_idx = + sycl::reduce_over_group(subgroup, warp_best_idx, sycl::minimum()); + + bool found = 
(warp_best_idx != invalid_idx); + if (found) { + int insert_pos = k; + bool still_shifting = true; +#pragma unroll + for (int shift = 0; shift < MaxReduceTopK - 1; ++shift) { + int prev_pos = k - shift - 1; + bool active = shift < k; + bool should_shift = active && still_shifting && + out_val[prev_pos] == warp_best_val && + out_idx[prev_pos] > warp_best_idx; + if (should_shift) { + out_val[prev_pos + 1] = out_val[prev_pos]; + out_idx[prev_pos + 1] = out_idx[prev_pos]; + insert_pos = prev_pos; } + still_shifting = still_shifting && should_shift; + } + out_val[insert_pos] = warp_best_val; + out_idx[insert_pos] = warp_best_idx; + } else { + out_val[k] = min_val; + out_idx[k] = 0; + } + + if (found && local_best_pos != -1 && local_best_val == warp_best_val && + local_best_idx == warp_best_idx) { + selected[local_best_pos] = true; } + } } template -inline void reduceTopK(sycl::sub_group subgroup, T* out_val, IdxT* out_idx, - T val, IdxT idx, T min_val, int topk) { - T in_vals[1] = {val}; - IdxT in_idxs[1] = {idx}; - reduceTopK<1>(subgroup, out_val, out_idx, in_vals, in_idxs, min_val, - topk); +inline void reduceTopK( + sycl::sub_group subgroup, + T* out_val, + IdxT* out_idx, + T val, + IdxT idx, + T min_val, + int topk) { + T in_vals[1] = {val}; + IdxT in_idxs[1] = {idx}; + reduceTopK<1>(subgroup, out_val, out_idx, in_vals, in_idxs, min_val, topk); } } // namespace reduce_topk -template +template < + typename T, + typename BiasT, + typename IdxT, + ScoringFunc SF, + int MaxNumExperts, + bool UseGroups, + int MaxNumTopExperts = DefaultMaxNumTopExperts> SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( - T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias, - int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, - int64_t const topk, int64_t const numExperts, - int64_t const numExpertsPerGroup, bool const renormalize, - double const routedScalingFactor, sycl::nd_item<1> item) { - - constexpr int NumWarps = 
MaxNumExperts / WARP_SIZE; - constexpr float invalidScoreFloat = -std::numeric_limits::infinity(); - - int threadIdx = item.get_local_id(0); - int blockIdx = item.get_group(0); - if constexpr (UseGroups){ - if (blockIdx >= numTokens) return; - } - int localSize = item.get_local_range(0); - bool has_bias = (routingBias != nullptr); - - int laneIdx = threadIdx % WARP_SIZE; - int warpIdx = threadIdx / WARP_SIZE; - - - topkValues += blockIdx * topk; - topkIndices += blockIdx * topk; - - if constexpr (UseGroups) { - auto subgroup = item.get_sub_group(); - T* scoresToken = scores + static_cast(blockIdx) * numExperts; - T selectedGroupScores[WARP_SIZE]; - int32_t selectedGroupIdx[WARP_SIZE]; - - T groupScore = neg_inf(); - if (laneIdx < numGroup) { - int32_t groupOffset = laneIdx * numExpertsPerGroup; - T largest = neg_inf(); - T secondLargest = neg_inf(); - - for (int32_t i = 0; i < numExpertsPerGroup; ++i) { - T value = apply_scoring(scoresToken[groupOffset + i]); - if (has_bias) { - value = value + sycl_cast(routingBias[groupOffset + i]); - } - if (value > largest) { - secondLargest = largest; - largest = value; - } else if (value > secondLargest) { - secondLargest = value; - } - } - groupScore = has_bias ? 
largest + secondLargest : largest; + T* scores, + float* topkValues, + IdxT* topkIndices, + BiasT const* routingBias, + int64_t const numTokens, + int64_t const numGroup, + int64_t const topkGroup, + int64_t const topk, + int64_t const numExperts, + int64_t const numExpertsPerGroup, + bool const renormalize, + double const routedScalingFactor, + sycl::nd_item<1> item) { + constexpr int NumWarps = MaxNumExperts / WARP_SIZE; + constexpr float invalidScoreFloat = -std::numeric_limits::infinity(); + + int threadIdx = item.get_local_id(0); + int blockIdx = item.get_group(0); + if constexpr (UseGroups) { + if (blockIdx >= numTokens) return; + } + int localSize = item.get_local_range(0); + bool has_bias = (routingBias != nullptr); + + int laneIdx = threadIdx % WARP_SIZE; + int warpIdx = threadIdx / WARP_SIZE; + + topkValues += blockIdx * topk; + topkIndices += blockIdx * topk; + + if constexpr (UseGroups) { + auto subgroup = item.get_sub_group(); + T* scoresToken = scores + static_cast(blockIdx) * numExperts; + T selectedGroupScores[WARP_SIZE]; + int32_t selectedGroupIdx[WARP_SIZE]; + + T groupScore = neg_inf(); + if (laneIdx < numGroup) { + int32_t groupOffset = laneIdx * numExpertsPerGroup; + T largest = neg_inf(); + T secondLargest = neg_inf(); + + for (int32_t i = 0; i < numExpertsPerGroup; ++i) { + T value = apply_scoring(scoresToken[groupOffset + i]); + if (has_bias) { + value = value + sycl_cast(routingBias[groupOffset + i]); } - - reduce_topk::reduceTopK( - subgroup, selectedGroupScores, selectedGroupIdx, - groupScore, laneIdx, neg_inf(), static_cast(topkGroup)); - - bool proceed = false; - if (topkGroup > 0) { - proceed = (selectedGroupScores[topkGroup - 1] != neg_inf()); + if (value > largest) { + secondLargest = largest; + largest = value; + } else if (value > secondLargest) { + secondLargest = value; } + } + groupScore = has_bias ? 
largest + secondLargest : largest; + } - if (!proceed) { - for (int i = laneIdx; i < topk; i += WARP_SIZE) { - topkIndices[i] = static_cast(i); - topkValues[i] = 1.0f / static_cast(topk); - } - return; - } + reduce_topk::reduceTopK( + subgroup, + selectedGroupScores, + selectedGroupIdx, + groupScore, + laneIdx, + neg_inf(), + static_cast(topkGroup)); + + bool proceed = false; + if (topkGroup > 0) { + proceed = (selectedGroupScores[topkGroup - 1] != neg_inf()); + } - constexpr int MaxExpertCandidatesPerLane = NumDeepseekExperts / WARP_SIZE; - T localCandidateScores[MaxExpertCandidatesPerLane]; - IdxT localCandidateIdx[MaxExpertCandidatesPerLane]; - T selectedExpertScores[DefaultMaxNumTopExperts]; - IdxT selectedExpertIdx[DefaultMaxNumTopExperts]; + if (!proceed) { + for (int i = laneIdx; i < topk; i += WARP_SIZE) { + topkIndices[i] = static_cast(i); + topkValues[i] = 1.0f / static_cast(topk); + } + return; + } - for (int i = 0; i < MaxExpertCandidatesPerLane; ++i) { - localCandidateScores[i] = neg_inf(); - localCandidateIdx[i] = 0; - } + constexpr int MaxExpertCandidatesPerLane = NumDeepseekExperts / WARP_SIZE; + T localCandidateScores[MaxExpertCandidatesPerLane]; + IdxT localCandidateIdx[MaxExpertCandidatesPerLane]; + T selectedExpertScores[DefaultMaxNumTopExperts]; + IdxT selectedExpertIdx[DefaultMaxNumTopExperts]; - int32_t totalCandidates = topkGroup * numExpertsPerGroup; - for (int32_t candidate = laneIdx; candidate < totalCandidates; - candidate += WARP_SIZE) { - int32_t localSlot = candidate / WARP_SIZE; - int32_t selectedGroup = candidate / numExpertsPerGroup; - int32_t expertInGroup = candidate % numExpertsPerGroup; - int32_t gid = selectedGroupIdx[selectedGroup]; - int32_t idx = gid * numExpertsPerGroup + expertInGroup; - T candidateScore = neg_inf(); - - T input = scoresToken[idx]; - if (is_finite(input)) { - T score = apply_scoring(input); - candidateScore = score; - if (has_bias) { - candidateScore = candidateScore + sycl_cast(routingBias[idx]); - } - } 
- - localCandidateScores[localSlot] = candidateScore; - localCandidateIdx[localSlot] = static_cast(idx); - } + for (int i = 0; i < MaxExpertCandidatesPerLane; ++i) { + localCandidateScores[i] = neg_inf(); + localCandidateIdx[i] = 0; + } - reduce_topk::reduceTopK( - subgroup, selectedExpertScores, selectedExpertIdx, - localCandidateScores, localCandidateIdx, neg_inf(), static_cast(topk)); - - for (int i = 1; i < topk; ++i) { - T score = selectedExpertScores[i]; - IdxT idx = selectedExpertIdx[i]; - int j = i; - while (j > 0 && - ((selectedExpertScores[j - 1] < score) || - ((selectedExpertScores[j - 1] == score) && - (selectedExpertIdx[j - 1] > idx)))) { - selectedExpertScores[j] = selectedExpertScores[j - 1]; - selectedExpertIdx[j] = selectedExpertIdx[j - 1]; - --j; - } - selectedExpertScores[j] = score; - selectedExpertIdx[j] = idx; + int32_t totalCandidates = topkGroup * numExpertsPerGroup; + for (int32_t candidate = laneIdx; candidate < totalCandidates; + candidate += WARP_SIZE) { + int32_t localSlot = candidate / WARP_SIZE; + int32_t selectedGroup = candidate / numExpertsPerGroup; + int32_t expertInGroup = candidate % numExpertsPerGroup; + int32_t gid = selectedGroupIdx[selectedGroup]; + int32_t idx = gid * numExpertsPerGroup + expertInGroup; + T candidateScore = neg_inf(); + + T input = scoresToken[idx]; + if (is_finite(input)) { + T score = apply_scoring(input); + candidateScore = score; + if (has_bias) { + candidateScore = + candidateScore + sycl_cast(routingBias[idx]); } + } - float laneUnbiased = 0.0f; - IdxT laneIdxOut = 0; - if (laneIdx < topk) { - laneIdxOut = selectedExpertIdx[laneIdx]; - T in = scoresToken[static_cast(laneIdxOut)]; - laneUnbiased = xpu::to_float(apply_scoring(in)); - } + localCandidateScores[localSlot] = candidateScore; + localCandidateIdx[localSlot] = static_cast(idx); + } - float scale = static_cast(routedScalingFactor); - if (renormalize) { - // Match baseline precision: sum and divide in T precision - T laneScoreT = 
static_cast(laneUnbiased); - T topkSumT = static_cast(0); - for (int i = 0; i < static_cast(topk); ++i) { - T val = sycl::select_from_group(subgroup, laneScoreT, i); - topkSumT = topkSumT + val; - } - if (laneIdx < topk) { - T normalizedT = laneScoreT / topkSumT; - laneUnbiased = xpu::to_float(normalizedT); - } - } + reduce_topk::reduceTopK( + subgroup, + selectedExpertScores, + selectedExpertIdx, + localCandidateScores, + localCandidateIdx, + neg_inf(), + static_cast(topk)); + + for (int i = 1; i < topk; ++i) { + T score = selectedExpertScores[i]; + IdxT idx = selectedExpertIdx[i]; + int j = i; + while (j > 0 && ((selectedExpertScores[j - 1] < score) || + ((selectedExpertScores[j - 1] == score) && + (selectedExpertIdx[j - 1] > idx)))) { + selectedExpertScores[j] = selectedExpertScores[j - 1]; + selectedExpertIdx[j] = selectedExpertIdx[j - 1]; + --j; + } + selectedExpertScores[j] = score; + selectedExpertIdx[j] = idx; + } - if (laneIdx < topk) { - topkIndices[laneIdx] = laneIdxOut; - topkValues[laneIdx] = laneUnbiased * scale; - } - return; - } else { + float laneUnbiased = 0.0f; + IdxT laneIdxOut = 0; + if (laneIdx < topk) { + laneIdxOut = selectedExpertIdx[laneIdx]; + T in = scoresToken[static_cast(laneIdxOut)]; + laneUnbiased = xpu::to_float(apply_scoring(in)); + } + + float scale = static_cast(routedScalingFactor); + if (renormalize) { + // Match baseline precision: sum and divide in T precision + T laneScoreT = static_cast(laneUnbiased); + T topkSumT = static_cast(0); + for (int i = 0; i < static_cast(topk); ++i) { + T val = sycl::select_from_group(subgroup, laneScoreT, i); + topkSumT = topkSumT + val; + } + if (laneIdx < topk) { + T normalizedT = laneScoreT / topkSumT; + laneUnbiased = xpu::to_float(normalizedT); + } + } - T* smemScoreSigmoid = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); - T* smemScoreBias = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); + if (laneIdx < topk) { + topkIndices[laneIdx] = 
laneIdxOut; + topkValues[laneIdx] = laneUnbiased * scale; + } + return; + } else { + T* smemScoreSigmoid = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + item.get_group()); + T* smemScoreBias = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + item.get_group()); T invalidScoreT = neg_inf(); T topScores[MaxNumTopExperts] = {neg_inf()}; int32_t topExperts[MaxNumTopExperts] = {0}; @@ -320,193 +359,283 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( auto group = item.get_sub_group(); for (int expert = threadIdx; expert < numExperts; expert += localSize) { - int64_t scoreIdx = int64_t{blockIdx} * int64_t{numExperts} + expert; - T score = scores[scoreIdx]; - T scoreSigmoid = apply_scoring(score); - smemScoreSigmoid[expert] = scoreSigmoid; - smemScoreBias[expert] = has_bias - ? (scoreSigmoid + sycl_cast(routingBias[expert])) - : scoreSigmoid; + int64_t scoreIdx = int64_t{blockIdx} * int64_t{numExperts} + expert; + T score = scores[scoreIdx]; + T scoreSigmoid = apply_scoring(score); + smemScoreSigmoid[expert] = scoreSigmoid; + smemScoreBias[expert] = + has_bias ? (scoreSigmoid + sycl_cast(routingBias[expert])) + : scoreSigmoid; } item.barrier(sycl::access::fence_space::local_space); if constexpr (MaxNumExperts > MaxNumExpertsUnit) { - constexpr int NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; - constexpr int NumInterTopK = NumExpertWarps * MaxNumTopExperts; - T* smemInterTopScores = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); - IdxT* smemInterTopExperts = *sycl::ext::oneapi::group_local_memory_for_overwrite(item.get_group()); - - if (warpIdx < NumExpertWarps) { - int32_t offset = warpIdx * WARP_SIZE * MaxNumTopGroups; - - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - int expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = offset + expertIdx; - expertScoreGroup[ii] = (offset + expertIdx < numExperts) - ? 
smemScoreBias[offset + expertIdx] - : invalidScoreT; - } - reduce_topk::reduceTopK( - group, topScores, topExperts, expertScoreGroup, expertIdxGroup, - invalidScoreT, static_cast(topk)); - - if (laneIdx < MaxNumTopExperts) { - if (laneIdx < topk) { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; - } else { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreT; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1; - } - } + constexpr int NumExpertWarps = + (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; + constexpr int NumInterTopK = NumExpertWarps * MaxNumTopExperts; + T* smemInterTopScores = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + item.get_group()); + IdxT* smemInterTopExperts = + *sycl::ext::oneapi::group_local_memory_for_overwrite< + int32_t[NumInterTopK]>(item.get_group()); + + if (warpIdx < NumExpertWarps) { + int32_t offset = warpIdx * WARP_SIZE * MaxNumTopGroups; + + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + int expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = (offset + expertIdx < numExperts) + ? 
smemScoreBias[offset + expertIdx] + : invalidScoreT; + } + reduce_topk::reduceTopK( + group, + topScores, + topExperts, + expertScoreGroup, + expertIdxGroup, + invalidScoreT, + static_cast(topk)); + + if (laneIdx < MaxNumTopExperts) { + if (laneIdx < topk) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = + topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = + topExperts[laneIdx]; + } else { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = + invalidScoreT; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = + MaxNumExperts - 1; + } } - item.barrier(sycl::access::fence_space::local_space); - if (warpIdx == 0) { - constexpr int NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; - T intermediateScore[NumInterTopKPerThread]; - int32_t intermediateExpert[NumInterTopKPerThread]; - T invalidScoreT = neg_inf(); - - for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) { - int ii = i / WARP_SIZE; - if (i < NumInterTopK) { - intermediateScore[ii] = smemInterTopScores[i]; - intermediateExpert[ii] = smemInterTopExperts[i]; - } else { - intermediateScore[ii] = invalidScoreT; - intermediateExpert[ii] = MaxNumExperts - 1; - } - } - - reduce_topk::reduceTopK( - group, topScores, topExperts, intermediateScore, intermediateExpert, - invalidScoreT, static_cast(topk)); + } + item.barrier(sycl::access::fence_space::local_space); + if (warpIdx == 0) { + constexpr int NumInterTopKPerThread = + (NumInterTopK - 1) / WARP_SIZE + 1; + T intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; + T invalidScoreT = neg_inf(); + + for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; + i += WARP_SIZE) { + int ii = i / WARP_SIZE; + if (i < NumInterTopK) { + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; + } else { + intermediateScore[ii] = invalidScoreT; + intermediateExpert[ii] = MaxNumExperts - 1; + } } + 
+ reduce_topk::reduceTopK( + group, + topScores, + topExperts, + intermediateScore, + intermediateExpert, + invalidScoreT, + static_cast(topk)); + } } else { - if (warpIdx == 0) { - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - int32_t expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] = (expertIdx < numExperts) - ? smemScoreBias[expertIdx] - : invalidScoreT; - } - reduce_topk::reduceTopK( - group, topScores, topExperts, expertScoreGroup, expertIdxGroup, - invalidScoreT, static_cast(topk)); + if (warpIdx == 0) { + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + int32_t expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = (expertIdx < numExperts) + ? smemScoreBias[expertIdx] + : invalidScoreT; } + reduce_topk::reduceTopK( + group, + topScores, + topExperts, + expertScoreGroup, + expertIdxGroup, + invalidScoreT, + static_cast(topk)); + } } if (warpIdx == 0) { - int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; - T temp; - xpu::from_float(temp, 0.F); - T scoreNormT = laneIdx < topk ? smemScoreSigmoid[expertIdx] : temp; - float scoreNorm = xpu::to_float(scoreNormT); - float finalScore = static_cast(scoreNorm * routedScalingFactor); - float topk_sum = 1e-20f; - if (renormalize) { - topk_sum += sycl::reduce_over_group(group, scoreNorm,sycl::plus()); - finalScore /= topk_sum; - } - if (laneIdx < topk) { - topkValues[laneIdx] = finalScore; - topkIndices[laneIdx] = expertIdx; - } + int32_t expertIdx = + laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + T temp; + xpu::from_float(temp, 0.F); + T scoreNormT = laneIdx < topk ? 
smemScoreSigmoid[expertIdx] : temp; + float scoreNorm = xpu::to_float(scoreNormT); + float finalScore = static_cast(scoreNorm * routedScalingFactor); + float topk_sum = 1e-20f; + if (renormalize) { + topk_sum += + sycl::reduce_over_group(group, scoreNorm, sycl::plus()); + finalScore /= topk_sum; + } + if (laneIdx < topk) { + topkValues[laneIdx] = finalScore; + topkIndices[laneIdx] = expertIdx; + } } - } // end if constexpr (!UseGroups) + } // end if constexpr (!UseGroups) } template -void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices, - BiasT const* bias, int64_t const num_tokens, - int64_t const num_experts, int64_t const n_group, - int64_t const topk_group, int64_t const topk, - bool const renormalize, double const routed_scaling_factor, - bool enable_pdl = false, sycl::queue queue = sycl::queue()) { - int64_t experts_per_group = num_experts / n_group; - bool is_single_group = - (n_group == 1) && (topk_group == 1) && - (num_experts <= MaxSupportedExpertCount) && - (topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts); - - #define LAUNCH_SMALL_KERNEL(MAX_EXPERTS, USE_GROUPS, MAX_TOP_EXPERTS, NUM_THREADS) \ - do { \ - size_t local_size = static_cast(NUM_THREADS); \ - size_t global_size = static_cast(num_tokens) * local_size; \ - queue.submit([&](sycl::handler& cgh) { \ - cgh.parallel_for>( \ - sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), \ - [=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { \ - grouped_topk_fused_small_expert_count_kernel( \ - scores, topk_values, topk_indices, bias, \ - num_tokens, n_group, topk_group, topk, num_experts, \ - experts_per_group, renormalize, routed_scaling_factor, item); \ - }); \ - }); \ - } while (0) - - if (is_single_group) { - if (num_experts == NumNemotronExperts && n_group == 1 && - topk == MaxSupportedTopExperts) { - LAUNCH_SMALL_KERNEL(NumNemotronExperts, false, - MaxSupportedTopExperts, - ((NumNemotronExperts + MaxNumExpertsUnit - 1) / - 
MaxNumExpertsUnit) * WARP_SIZE); - } else if (num_experts > NumKimiK2Experts && - num_experts <= MaxSupportedExpertCount) { - LAUNCH_SMALL_KERNEL(MaxSupportedExpertCount, false, - DefaultMaxNumTopExperts, - ((MaxSupportedExpertCount + MaxNumExpertsUnit - 1) / - MaxNumExpertsUnit) * WARP_SIZE); - } else if (num_experts > MaxNumExpertsUnit && - num_experts <= NumKimiK2Experts) { - LAUNCH_SMALL_KERNEL(NumKimiK2Experts, false, - DefaultMaxNumTopExperts, - ((NumKimiK2Experts + MaxNumExpertsUnit - 1) / - MaxNumExpertsUnit) * WARP_SIZE); - } else { - LAUNCH_SMALL_KERNEL(MaxNumExpertsUnit, false, - DefaultMaxNumTopExperts, - WARP_SIZE); - } +void invokeNoAuxTc( + T* scores, + float* topk_values, + IdxT* topk_indices, + BiasT const* bias, + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double const routed_scaling_factor, + bool enable_pdl = false, + sycl::queue queue = sycl::queue()) { + int64_t experts_per_group = num_experts / n_group; + bool is_single_group = + (n_group == 1) && (topk_group == 1) && + (num_experts <= MaxSupportedExpertCount) && + (topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts); + +#define LAUNCH_SMALL_KERNEL( \ + MAX_EXPERTS, USE_GROUPS, MAX_TOP_EXPERTS, NUM_THREADS) \ + do { \ + size_t local_size = static_cast(NUM_THREADS); \ + size_t global_size = static_cast(num_tokens) * local_size; \ + queue.submit([&](sycl::handler& cgh) { \ + cgh.parallel_for>( \ + sycl::nd_range<1>( \ + sycl::range<1>(global_size), sycl::range<1>(local_size)), \ + [=](sycl::nd_item<1> item) \ + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { \ + grouped_topk_fused_small_expert_count_kernel< \ + T, \ + BiasT, \ + IdxT, \ + SF, \ + MAX_EXPERTS, \ + USE_GROUPS, \ + MAX_TOP_EXPERTS>( \ + scores, \ + topk_values, \ + topk_indices, \ + bias, \ + num_tokens, \ + n_group, \ + topk_group, \ + topk, \ + num_experts, \ + experts_per_group, \ + renormalize, \ + 
routed_scaling_factor, \ + item); \ + }); \ + }); \ + } while (0) + + if (is_single_group) { + if (num_experts == NumNemotronExperts && n_group == 1 && + topk == MaxSupportedTopExperts) { + LAUNCH_SMALL_KERNEL( + NumNemotronExperts, + false, + MaxSupportedTopExperts, + ((NumNemotronExperts + MaxNumExpertsUnit - 1) / MaxNumExpertsUnit) * + WARP_SIZE); + } else if ( + num_experts > NumKimiK2Experts && + num_experts <= MaxSupportedExpertCount) { + LAUNCH_SMALL_KERNEL( + MaxSupportedExpertCount, + false, + DefaultMaxNumTopExperts, + ((MaxSupportedExpertCount + MaxNumExpertsUnit - 1) / + MaxNumExpertsUnit) * + WARP_SIZE); + } else if ( + num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts) { + LAUNCH_SMALL_KERNEL( + NumKimiK2Experts, + false, + DefaultMaxNumTopExperts, + ((NumKimiK2Experts + MaxNumExpertsUnit - 1) / MaxNumExpertsUnit) * + WARP_SIZE); } else { - LAUNCH_SMALL_KERNEL(NumDeepseekExperts, true, - DefaultMaxNumTopExperts, - WARP_SIZE); + LAUNCH_SMALL_KERNEL( + MaxNumExpertsUnit, false, DefaultMaxNumTopExperts, WARP_SIZE); } + } else { + LAUNCH_SMALL_KERNEL( + NumDeepseekExperts, true, DefaultMaxNumTopExperts, WARP_SIZE); + } - #undef LAUNCH_SMALL_KERNEL - +#undef LAUNCH_SMALL_KERNEL } -#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \ - template void invokeNoAuxTc( \ - T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \ - int64_t const num_tokens, int64_t const num_experts, \ - int64_t const n_group, int64_t const topk_group, int64_t const topk, \ - bool const renormalize, double const routed_scaling_factor, \ - bool enable_pdl, sycl::queue queue); +#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \ + template void invokeNoAuxTc( \ + T * scores, \ + float* topk_values, \ + IdxT* topk_indices, \ + BiasT const* bias, \ + int64_t const num_tokens, \ + int64_t const num_experts, \ + int64_t const n_group, \ + int64_t const topk_group, \ + int64_t const topk, \ + bool const renormalize, \ + double const routed_scaling_factor, \ 
+ bool enable_pdl, \ + sycl::queue queue); INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID); INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_SIGMOID); INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_SIGMOID); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, + sycl::ext::oneapi::bfloat16, + int32_t, + SCORING_SIGMOID); INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE); INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_NONE); INSTANTIATE_NOAUX_TC(float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_NONE); INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC( + sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_NONE); -INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, 
sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, + sycl::ext::oneapi::bfloat16, + int32_t, + SCORING_NONE); } // end namespace moe } // namespace vllm @@ -520,65 +649,85 @@ std::tuple fused_grouped_topk( c10::string_view const scoring_func, double const routed_scaling_factor, c10::optional const& bias) { - auto data_type = gating_output.scalar_type(); - bool has_bias = bias.has_value() && bias->defined(); - auto bias_type = has_bias ? bias->scalar_type() : torch::kFloat32; - auto input_size = gating_output.sizes(); - int64_t num_tokens = input_size[0]; - int64_t num_experts = input_size[1]; - int64_t n_group = n_expert_group; - int64_t topk_group = n_topk_group; - int64_t topk = n_topk; - - TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0], - "Number of tokens mismatch"); - TORCH_CHECK(input_size.size() == 2, "gating_output must be a 2D Tensor"); - TORCH_CHECK(n_group > 0, "n_group must be positive"); - TORCH_CHECK(topk > 0, "topk must be positive"); - TORCH_CHECK(topk_group > 0, "topk_group must be positive"); - TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group"); - TORCH_CHECK(num_experts % n_group == 0, - "num_experts should be divisible by n_group"); - TORCH_CHECK(n_group <= 32, - "n_group should be smaller than or equal to 32 for now"); - TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); - TORCH_CHECK(topk <= topk_group * (num_experts / n_group), - "topk must be <= topk_group * (num_experts / n_group)"); - TORCH_CHECK(scoring_func == "sigmoid" || scoring_func == "softmax", - "Unsupported scoring_func: ", scoring_func); - auto const sf = (scoring_func == "sigmoid") - ? 
vllm::moe::SCORING_SIGMOID - : vllm::moe::SCORING_NONE; - - - torch::Tensor topk_values = torch::empty( - {num_tokens, topk}, torch::dtype(torch::kFloat32).device(gating_output.device())); - torch::Tensor topk_indices = torch::empty( - {num_tokens, topk}, torch::dtype(torch::kInt32).device(gating_output.device())); - - auto device_idx = gating_output.device().index(); - auto stream = c10::xpu::getCurrentXPUStream(device_idx).queue(); + auto data_type = gating_output.scalar_type(); + bool has_bias = bias.has_value() && bias->defined(); + auto bias_type = has_bias ? bias->scalar_type() : torch::kFloat32; + auto input_size = gating_output.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + int64_t n_group = n_expert_group; + int64_t topk_group = n_topk_group; + int64_t topk = n_topk; + + TORCH_CHECK( + hidden_states.sizes()[0] == gating_output.sizes()[0], + "Number of tokens mismatch"); + TORCH_CHECK(input_size.size() == 2, "gating_output must be a 2D Tensor"); + TORCH_CHECK(n_group > 0, "n_group must be positive"); + TORCH_CHECK(topk > 0, "topk must be positive"); + TORCH_CHECK(topk_group > 0, "topk_group must be positive"); + TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group"); + TORCH_CHECK( + num_experts % n_group == 0, "num_experts should be divisible by n_group"); + TORCH_CHECK( + n_group <= 32, "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + TORCH_CHECK( + topk <= topk_group * (num_experts / n_group), + "topk must be <= topk_group * (num_experts / n_group)"); + TORCH_CHECK( + scoring_func == "sigmoid" || scoring_func == "softmax", + "Unsupported scoring_func: ", + scoring_func); + auto const sf = (scoring_func == "sigmoid") ? 
vllm::moe::SCORING_SIGMOID + : vllm::moe::SCORING_NONE; + + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, + torch::dtype(torch::kFloat32).device(gating_output.device())); + torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, + torch::dtype(torch::kInt32).device(gating_output.device())); + + auto device_idx = gating_output.device().index(); + auto stream = c10::xpu::getCurrentXPUStream(device_idx).queue(); #define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \ do { \ switch (sf) { \ case vllm::moe::SCORING_NONE: \ vllm::moe::invokeNoAuxTc( \ - reinterpret_cast(gating_output.mutable_data_ptr()), \ + reinterpret_cast(gating_output.mutable_data_ptr()), \ reinterpret_cast(topk_values.mutable_data_ptr()), \ reinterpret_cast(topk_indices.mutable_data_ptr()), \ - (has_bias ? reinterpret_cast(bias->data_ptr()) : nullptr), num_tokens, \ - num_experts, n_group, topk_group, topk, renormalize, \ - routed_scaling_factor, false, stream); \ + (has_bias ? reinterpret_cast(bias->data_ptr()) \ + : nullptr), \ + num_tokens, \ + num_experts, \ + n_group, \ + topk_group, \ + topk, \ + renormalize, \ + routed_scaling_factor, \ + false, \ + stream); \ break; \ case vllm::moe::SCORING_SIGMOID: \ vllm::moe::invokeNoAuxTc( \ - reinterpret_cast(gating_output.mutable_data_ptr()), \ + reinterpret_cast(gating_output.mutable_data_ptr()), \ reinterpret_cast(topk_values.mutable_data_ptr()), \ reinterpret_cast(topk_indices.mutable_data_ptr()), \ - (has_bias ? reinterpret_cast(bias->data_ptr()) : nullptr), num_tokens, \ - num_experts, n_group, topk_group, topk, renormalize, \ - routed_scaling_factor, false, stream); \ + (has_bias ? 
reinterpret_cast(bias->data_ptr()) \ + : nullptr), \ + num_tokens, \ + num_experts, \ + n_group, \ + topk_group, \ + topk, \ + renormalize, \ + routed_scaling_factor, \ + false, \ + stream); \ break; \ default: \ throw std::invalid_argument("Unsupported scoring_func"); \ @@ -586,27 +735,25 @@ std::tuple fused_grouped_topk( } \ } while (0) -#define LAUNCH_KERNEL(T, IdxT) \ - do{ \ - switch (bias_type) { \ - case torch::kFloat16: \ - LAUNCH_KERNEL_SF(T, sycl::half, IdxT); \ - break; \ - case torch::kFloat32: \ - LAUNCH_KERNEL_SF(T, float, IdxT); \ - break; \ - case torch::kBFloat16: \ - LAUNCH_KERNEL_SF(T, sycl::ext::oneapi::bfloat16, IdxT); \ - break; \ - default: \ - throw std::invalid_argument( \ - "Invalid bias dtype, only supports float16, float32, and " \ - "bfloat16"); \ - break; \ - } \ - } \ - while (0) - +#define LAUNCH_KERNEL(T, IdxT) \ + do { \ + switch (bias_type) { \ + case torch::kFloat16: \ + LAUNCH_KERNEL_SF(T, sycl::half, IdxT); \ + break; \ + case torch::kFloat32: \ + LAUNCH_KERNEL_SF(T, float, IdxT); \ + break; \ + case torch::kBFloat16: \ + LAUNCH_KERNEL_SF(T, sycl::ext::oneapi::bfloat16, IdxT); \ + break; \ + default: \ + throw std::invalid_argument( \ + "Invalid bias dtype, only supports float16, float32, and " \ + "bfloat16"); \ + break; \ + } \ + } while (0) switch (data_type) { case torch::kFloat16: diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h old mode 100755 new mode 100644 index 09cefdb53..9061839a2 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -56,7 +56,6 @@ std::tuple fused_grouped_topk( const double routed_scaling_factor, const c10::optional& bias); - void topk_softmax( torch::Tensor& topk_weights, torch::Tensor& topk_indices, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 57015ca13..fa3f4e07e 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -47,7 +47,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor? 
maybe_expert_map) -> () "); m.impl("moe_lora_align_block_size", torch::kXPU, &moe_lora_align_block_size); - // Apply grouped topk routing to select experts. m.def( "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " diff --git a/csrc/utils.h b/csrc/utils.h index 6f217b359..0b24dfba4 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -107,7 +107,6 @@ struct SyclTypeTrait { using Type = sycl::ext::oneapi::bfloat16; }; - template struct AccumulateType { private: diff --git a/tests/ops/grouped_topk_op.py b/tests/ops/grouped_topk_op.py index b3b391860..71c698c5f 100644 --- a/tests/ops/grouped_topk_op.py +++ b/tests/ops/grouped_topk_op.py @@ -20,7 +20,8 @@ def grouped_topk( e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) elif scoring_func == "sigmoid": @@ -34,35 +35,35 @@ def grouped_topk( # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) else: - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] - # For batch invariance, use sorted=True to ensure deterministic expert selection + group_scores = (scores.view(num_token, num_expert_group, + -1).max(dim=-1).values) # [n, n_group] + # For batch invariance, use sorted=True to ensure + # deterministic expert selection use_sorted = True - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ - 1 - ] # [n, top_k_group] + group_idx = 
torch.topk(group_scores, + k=topk_group, + dim=-1, + sorted=use_sorted)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) - .reshape(num_token, -1) - ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + score_mask = (group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.size(-1) // num_expert_group).reshape(num_token, -1)) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] if e_score_correction_bias is not None: topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk( - tmp_scores, k=topk, dim=-1, sorted=use_sorted - ) + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=use_sorted) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -72,7 +73,6 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) - def fused_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -96,8 +96,14 @@ def fused_grouped_topk( scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) topk_values, topk_indices = ops.grouped_topk( - scores, scores_with_bias.to(scores.dtype), num_expert_group, - topk_group, topk, renormalize, routed_scaling_factor) + scores, + scores_with_bias.to(scores.dtype), + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) return topk_values.to(torch.float32), topk_indices.to(torch.int32) @@ -118,11 +124,16 @@ def fused_grouped_topk_sycl( scores = torch.softmax(gating_output, dim=-1) elif scoring_func == "sigmoid": scores = gating_output - else: + else: raise 
ValueError(f"Unsupported scoring function: {scoring_func}") - return ops.fused_grouped_topk(hidden_states, scores, topk, - renormalize, num_expert_group, topk_group, - scoring_func, routed_scaling_factor, - e_score_correction_bias) - - + return ops.fused_grouped_topk( + hidden_states, + scores, + topk, + renormalize, + num_expert_group, + topk_group, + scoring_func, + routed_scaling_factor, + e_score_correction_bias, + ) From c6d67902a224e872dfd842be779af95003f4ae78 Mon Sep 17 00:00:00 2001 From: xiaolong Date: Wed, 29 Apr 2026 06:33:02 +0000 Subject: [PATCH 11/15] Revert unintended oneDNN and moe_ops.h changes Signed-off-by: xiaolong Signed-off-by: root --- csrc/moe/moe_ops.h | 0 tests/test_grouped_topk.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) mode change 100644 => 100755 csrc/moe/moe_ops.h diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h old mode 100644 new mode 100755 diff --git a/tests/test_grouped_topk.py b/tests/test_grouped_topk.py index 82968dcb9..ec383a8a0 100644 --- a/tests/test_grouped_topk.py +++ b/tests/test_grouped_topk.py @@ -90,4 +90,4 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, torch.testing.assert_close(baseline_topk_ids, test_topk_ids_sycl, atol=0, - rtol=0) \ No newline at end of file + rtol=0) From de7497fb16e21f062411bed173db4687acb6734a Mon Sep 17 00:00:00 2001 From: root Date: Thu, 21 May 2026 05:52:47 +0000 Subject: [PATCH 12/15] add optimization for single group Signed-off-by: root --- csrc/moe/fused_grouped_topk.cpp | 318 +++++++++++++++++--------------- 1 file changed, 172 insertions(+), 146 deletions(-) diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index ea619c771..defb25158 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -33,7 +33,7 @@ template < typename IdxT, ScoringFunc SF, int MaxNumExperts, - bool UseGroups, + bool MultiGroups, int MaxNumTopExperts = DefaultMaxNumTopExperts> class 
VllmGroupedTopKFusedSmallExpertCountKernel; @@ -87,7 +87,7 @@ inline void reduceTopK( T min_val, int topk) { constexpr IdxT invalid_idx = std::numeric_limits::max(); - bool selected[N_IN] = {false}; + bool selected[N_IN]; for (int k = 0; k < topk; ++k) { T local_best_val = min_val; @@ -175,7 +175,7 @@ template < typename IdxT, ScoringFunc SF, int MaxNumExperts, - bool UseGroups, + bool MultiGroups, int MaxNumTopExperts = DefaultMaxNumTopExperts> SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( T* scores, @@ -191,12 +191,11 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( bool const renormalize, double const routedScalingFactor, sycl::nd_item<1> item) { - constexpr int NumWarps = MaxNumExperts / WARP_SIZE; constexpr float invalidScoreFloat = -std::numeric_limits::infinity(); int threadIdx = item.get_local_id(0); int blockIdx = item.get_group(0); - if constexpr (UseGroups) { + if constexpr (MultiGroups) { if (blockIdx >= numTokens) return; } int localSize = item.get_local_range(0); @@ -208,7 +207,7 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( topkValues += blockIdx * topk; topkIndices += blockIdx * topk; - if constexpr (UseGroups) { + if constexpr (MultiGroups) { auto subgroup = item.get_sub_group(); T* scoresToken = scores + static_cast(blockIdx) * numExperts; T selectedGroupScores[WARP_SIZE]; @@ -345,144 +344,184 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( } return; } else { - T* smemScoreSigmoid = - *sycl::ext::oneapi::group_local_memory_for_overwrite( - item.get_group()); - T* smemScoreBias = - *sycl::ext::oneapi::group_local_memory_for_overwrite( - item.get_group()); - T invalidScoreT = neg_inf(); - T topScores[MaxNumTopExperts] = {neg_inf()}; - int32_t topExperts[MaxNumTopExperts] = {0}; - T expertScoreGroup[MaxNumTopGroups] = {neg_inf()}; - int32_t expertIdxGroup[MaxNumTopGroups] = {0}; - auto group = item.get_sub_group(); - - for (int expert = 
threadIdx; expert < numExperts; expert += localSize) { - int64_t scoreIdx = int64_t{blockIdx} * int64_t{numExperts} + expert; - T score = scores[scoreIdx]; - T scoreSigmoid = apply_scoring(score); - smemScoreSigmoid[expert] = scoreSigmoid; - smemScoreBias[expert] = - has_bias ? (scoreSigmoid + sycl_cast(routingBias[expert])) - : scoreSigmoid; - } + // Single-group path: select top-k from all experts using single warp. + constexpr int ExpertsPerLane = MaxNumExperts / WARP_SIZE; + auto subgroup = item.get_sub_group(); + T* scoresToken = scores + static_cast(blockIdx) * numExperts; - item.barrier(sycl::access::fence_space::local_space); - - if constexpr (MaxNumExperts > MaxNumExpertsUnit) { - constexpr int NumExpertWarps = - (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; - constexpr int NumInterTopK = NumExpertWarps * MaxNumTopExperts; - T* smemInterTopScores = - *sycl::ext::oneapi::group_local_memory_for_overwrite( - item.get_group()); - IdxT* smemInterTopExperts = - *sycl::ext::oneapi::group_local_memory_for_overwrite< - int32_t[NumInterTopK]>(item.get_group()); - - if (warpIdx < NumExpertWarps) { - int32_t offset = warpIdx * WARP_SIZE * MaxNumTopGroups; - - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - int expertIdx = ii * WARP_SIZE + laneIdx; - expertIdxGroup[ii] = offset + expertIdx; - expertScoreGroup[ii] = (offset + expertIdx < numExperts) - ? 
smemScoreBias[offset + expertIdx] - : invalidScoreT; - } - reduce_topk::reduceTopK( - group, - topScores, - topExperts, - expertScoreGroup, - expertIdxGroup, - invalidScoreT, - static_cast(topk)); - - if (laneIdx < MaxNumTopExperts) { - if (laneIdx < topk) { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = - topScores[laneIdx]; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = - topExperts[laneIdx]; - } else { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = - invalidScoreT; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = - MaxNumExperts - 1; - } + if constexpr (MaxNumExperts <= NumKimiK2Experts) { + // Direct path: <=384 experts, fits in registers + T localBiasedScores[ExpertsPerLane]; + IdxT localIdx[ExpertsPerLane]; + + for (int i = 0; i < ExpertsPerLane; ++i) { + int expertId = i * WARP_SIZE + laneIdx; + if (expertId < numExperts) { + T raw = scoresToken[expertId]; + T scored = apply_scoring(raw); + localBiasedScores[i] = + has_bias ? 
(scored + sycl_cast(routingBias[expertId])) + : scored; + } else { + localBiasedScores[i] = neg_inf(); } + localIdx[i] = static_cast(expertId); } - item.barrier(sycl::access::fence_space::local_space); - if (warpIdx == 0) { - constexpr int NumInterTopKPerThread = - (NumInterTopK - 1) / WARP_SIZE + 1; - T intermediateScore[NumInterTopKPerThread]; - int32_t intermediateExpert[NumInterTopKPerThread]; - T invalidScoreT = neg_inf(); - - for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; - i += WARP_SIZE) { - int ii = i / WARP_SIZE; - if (i < NumInterTopK) { - intermediateScore[ii] = smemInterTopScores[i]; - intermediateExpert[ii] = smemInterTopExperts[i]; - } else { - intermediateScore[ii] = invalidScoreT; - intermediateExpert[ii] = MaxNumExperts - 1; - } + + T selectedScores[MaxNumTopExperts]; + IdxT selectedIdx[MaxNumTopExperts]; + reduce_topk::reduceTopK( + subgroup, + selectedScores, + selectedIdx, + localBiasedScores, + localIdx, + neg_inf(), + static_cast(topk)); + + float unbiasedArr[DefaultMaxNumTopExperts]; + for (int i = 0; i < static_cast(topk); ++i) { + IdxT idx_i = selectedIdx[i]; + T raw = scoresToken[static_cast(idx_i)]; + T scored = apply_scoring(raw); + unbiasedArr[i] = xpu::to_float(scored); + } + + float laneUnbiased = 0.0f; + IdxT laneIdxOut = 0; + if (laneIdx < topk) { + laneIdxOut = selectedIdx[laneIdx]; + laneUnbiased = unbiasedArr[laneIdx]; + } + + float scale = static_cast(routedScalingFactor); + if (renormalize) { + float topk_sum = 0.0f; + for (int i = 0; i < static_cast(topk); ++i) { + topk_sum += unbiasedArr[i]; + } + if (laneIdx < topk) { + laneUnbiased /= topk_sum; } + } - reduce_topk::reduceTopK( - group, - topScores, - topExperts, - intermediateScore, - intermediateExpert, - invalidScoreT, - static_cast(topk)); + if (laneIdx < topk) { + topkValues[laneIdx] = laneUnbiased * scale; + topkIndices[laneIdx] = laneIdxOut; } } else { - if (warpIdx == 0) { - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - int32_t expertIdx = ii * 
WARP_SIZE + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] = (expertIdx < numExperts) - ? smemScoreBias[expertIdx] - : invalidScoreT; + // Large expert path (>256): Each lane streams through its experts, + // maintaining a local top-k via insertion sort, then iteratively + // selects global top-k across the warp. Minimal register pressure. + constexpr int LocalTopK = DefaultMaxNumTopExperts; // 8 + auto subgroup = item.get_sub_group(); + T* scoresToken2 = scores + static_cast(blockIdx) * numExperts; + + // Each lane maintains its local top-k sorted descending + float laneTopVal[LocalTopK]; + IdxT laneTopIdx[LocalTopK]; + for (int k = 0; k < LocalTopK; ++k) { + laneTopVal[k] = -std::numeric_limits::infinity(); + laneTopIdx[k] = std::numeric_limits::max(); + } + + // Stream through all experts assigned to this lane + for (int i = 0; i < ExpertsPerLane; ++i) { + int expertId = i * WARP_SIZE + laneIdx; + if (expertId >= numExperts) break; + + T raw = scoresToken2[expertId]; + T scored = apply_scoring(raw); + float val = + has_bias ? xpu::to_float( + scored + sycl_cast(routingBias[expertId])) + : xpu::to_float(scored); + IdxT idx = static_cast(expertId); + + // Check if this beats the worst in our local top-k + if (val > laneTopVal[LocalTopK - 1] || + (val == laneTopVal[LocalTopK - 1] && + idx < laneTopIdx[LocalTopK - 1])) { + // Insert into sorted position + int pos = LocalTopK - 1; + for (int j = LocalTopK - 2; j >= 0; --j) { + if (val > laneTopVal[j] || + (val == laneTopVal[j] && idx < laneTopIdx[j])) { + laneTopVal[j + 1] = laneTopVal[j]; + laneTopIdx[j + 1] = laneTopIdx[j]; + pos = j; + } else { + break; + } + } + laneTopVal[pos] = val; + laneTopIdx[pos] = idx; } - reduce_topk::reduceTopK( - group, - topScores, - topExperts, - expertScoreGroup, - expertIdxGroup, - invalidScoreT, - static_cast(topk)); } - } - if (warpIdx == 0) { - int32_t expertIdx = - laneIdx < topk ? 
topExperts[laneIdx] : MaxNumExperts - 1; - T temp; - xpu::from_float(temp, 0.F); - T scoreNormT = laneIdx < topk ? smemScoreSigmoid[expertIdx] : temp; - float scoreNorm = xpu::to_float(scoreNormT); - float finalScore = static_cast(scoreNorm * routedScalingFactor); - float topk_sum = 1e-20f; + // Now iteratively select global top-k from the warp. + // Each lane offers its current best (top of local sorted list). + // The warp finds the global max, the winning lane pops it. + IdxT globalTopIdx[DefaultMaxNumTopExperts]; + int lanePtr = 0; // points to next candidate in laneTopVal + + for (int k = 0; k < static_cast(topk); ++k) { + float myVal = (lanePtr < LocalTopK) + ? laneTopVal[lanePtr] + : -std::numeric_limits::infinity(); + IdxT myIdx = (lanePtr < LocalTopK) ? laneTopIdx[lanePtr] + : std::numeric_limits::max(); + + // Find the best value across all lanes + float bestVal = + sycl::reduce_over_group(subgroup, myVal, sycl::maximum()); + + // Among lanes that have bestVal, pick smallest idx + IdxT candidateIdx = + (myVal == bestVal) ? 
myIdx : std::numeric_limits::max(); + IdxT bestIdx = sycl::reduce_over_group( + subgroup, candidateIdx, sycl::minimum()); + + globalTopIdx[k] = bestIdx; + + // The winning lane advances its pointer + if (myIdx == bestIdx && myVal == bestVal) { + lanePtr++; + } + } + + // Re-read unbiased scores for winners + float unbiasedArr[DefaultMaxNumTopExperts]; + for (int i = 0; i < static_cast(topk); ++i) { + T raw = scoresToken2[static_cast(globalTopIdx[i])]; + T scored = apply_scoring(raw); + unbiasedArr[i] = xpu::to_float(scored); + } + + float laneUnbiased = 0.0f; + IdxT laneIdxOut = 0; + if (laneIdx < topk) { + laneIdxOut = globalTopIdx[laneIdx]; + laneUnbiased = unbiasedArr[laneIdx]; + } + + float scale = static_cast(routedScalingFactor); if (renormalize) { - topk_sum += - sycl::reduce_over_group(group, scoreNorm, sycl::plus()); - finalScore /= topk_sum; + float topk_sum = 0.0f; + for (int i = 0; i < static_cast(topk); ++i) { + topk_sum += unbiasedArr[i]; + } + if (laneIdx < topk) { + laneUnbiased /= topk_sum; + } } + if (laneIdx < topk) { - topkValues[laneIdx] = finalScore; - topkIndices[laneIdx] = expertIdx; + topkValues[laneIdx] = laneUnbiased * scale; + topkIndices[laneIdx] = laneIdxOut; } } - } // end if constexpr (!UseGroups) + } // end if constexpr (!MultiGroups) } template @@ -553,29 +592,16 @@ void invokeNoAuxTc( if (num_experts == NumNemotronExperts && n_group == 1 && topk == MaxSupportedTopExperts) { LAUNCH_SMALL_KERNEL( - NumNemotronExperts, - false, - MaxSupportedTopExperts, - ((NumNemotronExperts + MaxNumExpertsUnit - 1) / MaxNumExpertsUnit) * - WARP_SIZE); + NumNemotronExperts, false, MaxSupportedTopExperts, WARP_SIZE); } else if ( num_experts > NumKimiK2Experts && num_experts <= MaxSupportedExpertCount) { LAUNCH_SMALL_KERNEL( - MaxSupportedExpertCount, - false, - DefaultMaxNumTopExperts, - ((MaxSupportedExpertCount + MaxNumExpertsUnit - 1) / - MaxNumExpertsUnit) * - WARP_SIZE); + MaxSupportedExpertCount, false, DefaultMaxNumTopExperts, WARP_SIZE); } 
else if ( num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts) { LAUNCH_SMALL_KERNEL( - NumKimiK2Experts, - false, - DefaultMaxNumTopExperts, - ((NumKimiK2Experts + MaxNumExpertsUnit - 1) / MaxNumExpertsUnit) * - WARP_SIZE); + NumKimiK2Experts, false, DefaultMaxNumTopExperts, WARP_SIZE); } else { LAUNCH_SMALL_KERNEL( MaxNumExpertsUnit, false, DefaultMaxNumTopExperts, WARP_SIZE); @@ -773,4 +799,4 @@ std::tuple fused_grouped_topk( #undef LAUNCH_KERNEL #undef LAUNCH_KERNEL_SF return {topk_values, topk_indices}; -} +} \ No newline at end of file From 2bc9f69d55ac0c722a915011be6799192816d2f6 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 26 May 2026 05:26:55 +0000 Subject: [PATCH 13/15] change kernel launch method Signed-off-by: root --- csrc/moe/fused_grouped_topk.cpp | 44 +++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index defb25158..98a8e095a 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -58,6 +58,27 @@ inline float sigmoid_accurate(float x) { return 1.f / (1.f + sycl::native::exp(-x)); } + +template +inline T warp_reduce_max(sycl::sub_group sg, T val) { + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + T other = sycl::select_from_group(sg, val, + (sg.get_local_linear_id() ^ offset)); + val = sycl::max(val, other); + } + return val; +} + +template +inline T warp_reduce_min(sycl::sub_group sg, T val) { + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + T other = sycl::select_from_group(sg, val, + (sg.get_local_linear_id() ^ offset)); + val = sycl::min(val, other); + } + return val; +} + template inline T apply_sigmoid(T val) { float f = xpu::to_float(val); @@ -109,8 +130,7 @@ inline void reduceTopK( } } float local_best_val_tmp = xpu::to_float(local_best_val); - float warp_best_val_tmp = sycl::reduce_over_group( - subgroup, local_best_val_tmp, sycl::maximum()); + float 
warp_best_val_tmp = warp_reduce_max(subgroup, local_best_val_tmp); T warp_best_val = static_cast(warp_best_val_tmp); IdxT warp_best_idx = invalid_idx; @@ -118,8 +138,7 @@ inline void reduceTopK( if (local_best_pos != -1 && local_best_val == warp_best_val) { warp_best_idx = local_best_idx; } - warp_best_idx = - sycl::reduce_over_group(subgroup, warp_best_idx, sycl::minimum()); + warp_best_idx = warp_reduce_min(subgroup, warp_best_idx); bool found = (warp_best_idx != invalid_idx); if (found) { @@ -473,14 +492,12 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( : std::numeric_limits::max(); // Find the best value across all lanes - float bestVal = - sycl::reduce_over_group(subgroup, myVal, sycl::maximum()); + float bestVal = warp_reduce_max(subgroup, myVal); // Among lanes that have bestVal, pick smallest idx IdxT candidateIdx = (myVal == bestVal) ? myIdx : std::numeric_limits::max(); - IdxT bestIdx = sycl::reduce_over_group( - subgroup, candidateIdx, sycl::minimum()); + IdxT bestIdx = warp_reduce_min(subgroup, candidateIdx); globalTopIdx[k] = bestIdx; @@ -537,8 +554,8 @@ void invokeNoAuxTc( int64_t const topk, bool const renormalize, double const routed_scaling_factor, - bool enable_pdl = false, - sycl::queue queue = sycl::queue()) { + bool enable_pdl, + sycl::queue& queue) { int64_t experts_per_group = num_experts / n_group; bool is_single_group = (n_group == 1) && (topk_group == 1) && @@ -628,7 +645,7 @@ void invokeNoAuxTc( bool const renormalize, \ double const routed_scaling_factor, \ bool enable_pdl, \ - sycl::queue queue); + sycl::queue& queue); INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID); INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_SIGMOID); @@ -715,8 +732,9 @@ std::tuple fused_grouped_topk( {num_tokens, topk}, torch::dtype(torch::kInt32).device(gating_output.device())); - auto device_idx = gating_output.device().index(); - auto stream = c10::xpu::getCurrentXPUStream(device_idx).queue(); + // auto 
device_idx = gating_output.device().index(); + auto device = gating_output.device(); + auto& stream = vllm::xpu::vllmGetQueue(device.index()); #define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \ do { \ From 3a551f98e4ed168c791e2a8f72cbc012336f531c Mon Sep 17 00:00:00 2001 From: root Date: Tue, 26 May 2026 07:04:26 +0000 Subject: [PATCH 14/15] pre-commit check Signed-off-by: root --- csrc/moe/fused_grouped_topk.cpp | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index 98a8e095a..cd8d00fa1 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -58,25 +58,24 @@ inline float sigmoid_accurate(float x) { return 1.f / (1.f + sycl::native::exp(-x)); } - template inline T warp_reduce_max(sycl::sub_group sg, T val) { - for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { - T other = sycl::select_from_group(sg, val, - (sg.get_local_linear_id() ^ offset)); - val = sycl::max(val, other); - } - return val; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + T other = + sycl::select_from_group(sg, val, (sg.get_local_linear_id() ^ offset)); + val = sycl::max(val, other); + } + return val; } template inline T warp_reduce_min(sycl::sub_group sg, T val) { - for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { - T other = sycl::select_from_group(sg, val, - (sg.get_local_linear_id() ^ offset)); - val = sycl::min(val, other); - } - return val; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + T other = + sycl::select_from_group(sg, val, (sg.get_local_linear_id() ^ offset)); + val = sycl::min(val, other); + } + return val; } template From 01db69167b9304ccc6f27c93d5e4d1b77857a987 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 27 May 2026 02:04:39 +0000 Subject: [PATCH 15/15] Correct the operator results to make them consistent with torch Signed-off-by: root --- csrc/moe/fused_grouped_topk.cpp | 41 
+++++++++++---------------------- tests/register_ops.py | 4 ++-- tests/test_grouped_topk.py | 2 +- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index cd8d00fa1..7600da8e6 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -58,26 +58,6 @@ inline float sigmoid_accurate(float x) { return 1.f / (1.f + sycl::native::exp(-x)); } -template -inline T warp_reduce_max(sycl::sub_group sg, T val) { - for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { - T other = - sycl::select_from_group(sg, val, (sg.get_local_linear_id() ^ offset)); - val = sycl::max(val, other); - } - return val; -} - -template -inline T warp_reduce_min(sycl::sub_group sg, T val) { - for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { - T other = - sycl::select_from_group(sg, val, (sg.get_local_linear_id() ^ offset)); - val = sycl::min(val, other); - } - return val; -} - template inline T apply_sigmoid(T val) { float f = xpu::to_float(val); @@ -107,7 +87,7 @@ inline void reduceTopK( T min_val, int topk) { constexpr IdxT invalid_idx = std::numeric_limits::max(); - bool selected[N_IN]; + bool selected[N_IN] = {false}; for (int k = 0; k < topk; ++k) { T local_best_val = min_val; @@ -129,7 +109,8 @@ inline void reduceTopK( } } float local_best_val_tmp = xpu::to_float(local_best_val); - float warp_best_val_tmp = warp_reduce_max(subgroup, local_best_val_tmp); + float warp_best_val_tmp = sycl::reduce_over_group( + subgroup, local_best_val_tmp, sycl::maximum()); T warp_best_val = static_cast(warp_best_val_tmp); IdxT warp_best_idx = invalid_idx; @@ -137,7 +118,8 @@ inline void reduceTopK( if (local_best_pos != -1 && local_best_val == warp_best_val) { warp_best_idx = local_best_idx; } - warp_best_idx = warp_reduce_min(subgroup, warp_best_idx); + warp_best_idx = + sycl::reduce_over_group(subgroup, warp_best_idx, sycl::minimum()); bool found = (warp_best_idx != invalid_idx); if 
(found) { @@ -262,6 +244,7 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( static_cast(topkGroup)); bool proceed = false; + if (topkGroup > 0) { proceed = (selectedGroupScores[topkGroup - 1] != neg_inf()); } @@ -491,12 +474,14 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( : std::numeric_limits::max(); // Find the best value across all lanes - float bestVal = warp_reduce_max(subgroup, myVal); + float bestVal = + sycl::reduce_over_group(subgroup, myVal, sycl::maximum()); // Among lanes that have bestVal, pick smallest idx IdxT candidateIdx = (myVal == bestVal) ? myIdx : std::numeric_limits::max(); - IdxT bestIdx = warp_reduce_min(subgroup, candidateIdx); + IdxT bestIdx = sycl::reduce_over_group( + subgroup, candidateIdx, sycl::minimum()); globalTopIdx[k] = bestIdx; @@ -562,7 +547,7 @@ void invokeNoAuxTc( (topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts); #define LAUNCH_SMALL_KERNEL( \ - MAX_EXPERTS, USE_GROUPS, MAX_TOP_EXPERTS, NUM_THREADS) \ + MAX_EXPERTS, MultiGroups, MAX_TOP_EXPERTS, NUM_THREADS) \ do { \ size_t local_size = static_cast(NUM_THREADS); \ size_t global_size = static_cast(num_tokens) * local_size; \ @@ -573,7 +558,7 @@ void invokeNoAuxTc( IdxT, \ SF, \ MAX_EXPERTS, \ - USE_GROUPS, \ + MultiGroups, \ MAX_TOP_EXPERTS>>( \ sycl::nd_range<1>( \ sycl::range<1>(global_size), sycl::range<1>(local_size)), \ @@ -585,7 +570,7 @@ void invokeNoAuxTc( IdxT, \ SF, \ MAX_EXPERTS, \ - USE_GROUPS, \ + MultiGroups, \ MAX_TOP_EXPERTS>( \ scores, \ topk_values, \ diff --git a/tests/register_ops.py b/tests/register_ops.py index 0508ed06d..136ae8e22 100644 --- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -458,7 +458,7 @@ def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor, def fused_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, - topk: int, + n_topk: int, renormalize: bool, num_expert_group: int, topk_group: int, @@ -467,7 +467,7 @@ def 
fused_grouped_topk( e_score_correction_bias: Optional[torch.Tensor] = None, ): return torch.ops._moe_C.fused_grouped_topk(hidden_states, gating_output, - topk, renormalize, + n_topk, renormalize, num_expert_group, topk_group, scoring_func, routed_scaling_factor, diff --git a/tests/test_grouped_topk.py b/tests/test_grouped_topk.py index ec383a8a0..82968dcb9 100644 --- a/tests/test_grouped_topk.py +++ b/tests/test_grouped_topk.py @@ -90,4 +90,4 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, torch.testing.assert_close(baseline_topk_ids, test_topk_ids_sycl, atol=0, - rtol=0) + rtol=0) \ No newline at end of file