diff --git a/benchmark/benchmark_grouped_topk.py b/benchmark/benchmark_grouped_topk.py index 28bb2b034..ceb6530cf 100644 --- a/benchmark/benchmark_grouped_topk.py +++ b/benchmark/benchmark_grouped_topk.py @@ -209,4 +209,4 @@ def benchmark( args = parser.parse_args() benchmark = get_benchmark() - benchmark.run(print_data=True, save_path=args.save_path) + benchmark.run(print_data=True, save_path=args.save_path) \ No newline at end of file diff --git a/csrc/moe/fused_grouped_topk.cpp b/csrc/moe/fused_grouped_topk.cpp index 49471084e..7600da8e6 100644 --- a/csrc/moe/fused_grouped_topk.cpp +++ b/csrc/moe/fused_grouped_topk.cpp @@ -1,511 +1,804 @@ +#include +#include +#include +#include #include - -#include "../utils.h" #include "../dispatch_utils.h" - +#include "../utils.h" namespace vllm { -namespace GroupedTopKImpl { - -enum class ScoringFunc { - DEFAULT = 0, - SOFTMAX = 1, - SIGMOID = 2, -}; - -template -struct Fused_Grouped_Topk { - static constexpr int sub_group_size = 32; - static constexpr int max_group_size = 1024; - static constexpr int malloc_per_item = MAX_EXPERT_GROUPS; - static constexpr float kNegInfinity = INFINITY * -1; - - Fused_Grouped_Topk( - float* topk_weights, - int* topk_ids, - const T* gating_output, - const T* e_score_correction_bias, - const double routed_scaling_factor, - const ScoringFunc scoring_mode, - const bool renormalize, - const int tokens, - const int experts, - const int top_k, - const int num_expert_group, - const int topk_group) - : topk_weights(topk_weights), - topk_ids(topk_ids), - gating_output(gating_output), - e_score_correction_bias(e_score_correction_bias), - routed_scaling_factor(routed_scaling_factor), - scoring_mode(scoring_mode), - renormalize(renormalize), - tokens(tokens), - experts(experts), - top_k(top_k), - num_expert_group(num_expert_group), - topk_group(topk_group) {} - - static inline sycl::nd_range<3> - get_nd_range(const int tokens, const int experts) { - int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; - int group_size = (experts + calc_per_item - 1) / calc_per_item; - group_size = group_size < sub_group_size ? sub_group_size : group_size; - group_size = group_size < max_group_size ? group_size : max_group_size; - int sub_groups_per_group = - (group_size + sub_group_size - 1) / sub_group_size; - group_size = sub_groups_per_group * sub_group_size; - int global_size = - (tokens + sub_groups_per_group - 1) / sub_groups_per_group; - - sycl::range<3> local(1, 1, group_size); - sycl::range<3> global(1, 1, global_size); - return sycl::nd_range<3>(global * local, local); - } +namespace moe { + +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +static constexpr int WARP_SIZE = 32; +static constexpr int NumNemotronExperts = 512; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = + std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); +static constexpr int MaxNumExpertsUnit = 128; +static constexpr int NumTopGroupScores = 2; +static constexpr int DefaultMaxNumTopExperts = 8; +static constexpr int MaxSupportedTopExperts = 22; +static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxReduceTopK = 32; + +enum ScoringFunc : int { SCORING_NONE = 0, SCORING_SIGMOID = 1 }; + +template +class VllmGroupedTopKFusedKernel; + +template < + typename T, + typename BiasT, + typename IdxT, + ScoringFunc SF, + int MaxNumExperts, + bool MultiGroups, + int MaxNumTopExperts = DefaultMaxNumTopExperts> +class VllmGroupedTopKFusedSmallExpertCountKernel; + +template +inline T_OUT sycl_cast(T_IN val) { + return static_cast(val); +} - static inline float Sigmoid(float x) { - return 1.0f / (1.0f + sycl::native::exp(-x)); - } +template +inline T neg_inf() { + T out; + xpu::from_float(out, -std::numeric_limits::infinity()); + return out; +} - [[sycl::reqd_sub_group_size(sub_group_size)]] void - operator()(sycl::nd_item<3> item) const { - int group_id = item.get_group_linear_id(); - int local_range = item.get_local_range(2); - int sub_groups_per_group = local_range / sub_group_size; - int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; +template +inline bool is_finite(const T val) { + return std::isfinite(xpu::to_float(val)); +} - int experts_per_group = experts / num_expert_group; +inline float sigmoid_accurate(float x) { + return 1.f / (1.f + sycl::native::exp(-x)); +} - sycl::sub_group sg = item.get_sub_group(); - int sg_id = sg.get_group_id(); - int sg_local_id = sg.get_local_id(); +template +inline T apply_sigmoid(T val) { + float f = xpu::to_float(val); + T out; + xpu::from_float(out, sigmoid_accurate(f)); + return out; +} - int tid = group_id * sub_groups_per_group + sg_id; +template +inline T apply_scoring(T val) { + if constexpr (SF == SCORING_NONE) { + return val; + } else if constexpr (SF == SCORING_SIGMOID) { + return apply_sigmoid(val); + } +} - if (tid >= tokens) { - return; // Out of bounds +namespace reduce_topk { + +template +inline void reduceTopK( + sycl::sub_group subgroup, + T* out_val, + IdxT* out_idx, + const T* in_vals, + const IdxT* in_idxs, + T min_val, + int topk) { + constexpr IdxT invalid_idx = std::numeric_limits::max(); + bool selected[N_IN] = {false}; + + for (int k = 0; k < topk; ++k) { + T local_best_val = min_val; + IdxT local_best_idx = invalid_idx; + int local_best_pos = -1; + +#pragma unroll + for (int i = 0; i < N_IN; ++i) { + if (selected[i]) { + continue; + } + T cand_val = in_vals[i]; + IdxT cand_idx = in_idxs[i]; + if ((cand_val > local_best_val) || + ((cand_val == local_best_val) && (cand_idx < local_best_idx))) { + local_best_val = cand_val; + local_best_idx = cand_idx; + local_best_pos = i; + } } + float local_best_val_tmp = xpu::to_float(local_best_val); + float warp_best_val_tmp = sycl::reduce_over_group( + subgroup, local_best_val_tmp, sycl::maximum()); - T load_elems[malloc_per_item]; - int local_idx[malloc_per_item]; - T bias[malloc_per_item]; - - int start_offset = sg_local_id * calc_per_item; - int local_num = calc_per_item; + T warp_best_val = static_cast(warp_best_val_tmp); + IdxT warp_best_idx = invalid_idx; - if (start_offset + local_num >= experts) { - local_num = experts - start_offset; - if (local_num < 0) { - local_num = 0; // No elements to process + if (local_best_pos != -1 && local_best_val == warp_best_val) { + warp_best_idx = local_best_idx; + } + warp_best_idx = + sycl::reduce_over_group(subgroup, warp_best_idx, sycl::minimum()); + + bool found = (warp_best_idx != invalid_idx); + if (found) { + int insert_pos = k; + bool still_shifting = true; +#pragma unroll + for (int shift = 0; shift < MaxReduceTopK - 1; ++shift) { + int prev_pos = k - shift - 1; + bool active = shift < k; + bool should_shift = active && still_shifting && + out_val[prev_pos] == warp_best_val && + out_idx[prev_pos] > warp_best_idx; + if (should_shift) { + out_val[prev_pos + 1] = out_val[prev_pos]; + out_idx[prev_pos + 1] = out_idx[prev_pos]; + insert_pos = prev_pos; + } + still_shifting = still_shifting && should_shift; } + out_val[insert_pos] = warp_best_val; + out_idx[insert_pos] = warp_best_idx; + } else { + out_val[k] = min_val; + out_idx[k] = 0; } - for (int e = 0; e < calc_per_item; ++e) { - load_elems[e] = kNegInfinity; - local_idx[e] = -1; - bias[e] = 0.0f; // Initialize bias to zero + if (found && local_best_pos != -1 && local_best_val == warp_best_val && + local_best_idx == warp_best_idx) { + selected[local_best_pos] = true; } + } +} + +template +inline void reduceTopK( + sycl::sub_group subgroup, + T* out_val, + IdxT* out_idx, + T val, + IdxT idx, + T min_val, + int topk) { + T in_vals[1] = {val}; + IdxT in_idxs[1] = {idx}; + reduceTopK<1>(subgroup, out_val, out_idx, in_vals, in_idxs, min_val, topk); +} - for (int e = 0; e < local_num; ++e) { - load_elems[e] = gating_output[tid * experts + start_offset + e]; +} // namespace reduce_topk + +template < + typename T, + typename BiasT, + typename IdxT, + ScoringFunc SF, + int MaxNumExperts, + bool MultiGroups, + int MaxNumTopExperts = DefaultMaxNumTopExperts> +SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel( + T* scores, + float* topkValues, + IdxT* topkIndices, + BiasT const* routingBias, + int64_t const numTokens, + int64_t const numGroup, + int64_t const topkGroup, + int64_t const topk, + int64_t const numExperts, + int64_t const numExpertsPerGroup, + bool const renormalize, + double const routedScalingFactor, + sycl::nd_item<1> item) { + constexpr float invalidScoreFloat = -std::numeric_limits::infinity(); + + int threadIdx = item.get_local_id(0); + int blockIdx = item.get_group(0); + if constexpr (MultiGroups) { + if (blockIdx >= numTokens) return; + } + int localSize = item.get_local_range(0); + bool has_bias = (routingBias != nullptr); + + int laneIdx = threadIdx % WARP_SIZE; + int warpIdx = threadIdx / WARP_SIZE; + + topkValues += blockIdx * topk; + topkIndices += blockIdx * topk; + + if constexpr (MultiGroups) { + auto subgroup = item.get_sub_group(); + T* scoresToken = scores + static_cast(blockIdx) * numExperts; + T selectedGroupScores[WARP_SIZE]; + int32_t selectedGroupIdx[WARP_SIZE]; + + T groupScore = neg_inf(); + if (laneIdx < numGroup) { + int32_t groupOffset = laneIdx * numExpertsPerGroup; + T largest = neg_inf(); + T secondLargest = neg_inf(); + + for (int32_t i = 0; i < numExpertsPerGroup; ++i) { + T value = apply_scoring(scoresToken[groupOffset + i]); + if (has_bias) { + value = value + sycl_cast(routingBias[groupOffset + i]); + } + if (value > largest) { + secondLargest = largest; + largest = value; + } else if (value > secondLargest) { + secondLargest = value; + } + } + groupScore = has_bias ? largest + secondLargest : largest; } - T local_elems[malloc_per_item]; + reduce_topk::reduceTopK( + subgroup, + selectedGroupScores, + selectedGroupIdx, + groupScore, + laneIdx, + neg_inf(), + static_cast(topkGroup)); + + bool proceed = false; - for (int e = 0; e < local_num; ++e) { - local_elems[e] = load_elems[e]; - local_idx[e] = start_offset + e; + if (topkGroup > 0) { + proceed = (selectedGroupScores[topkGroup - 1] != neg_inf()); } - if (scoring_mode == ScoringFunc::SOFTMAX) { - float softmax_max = kNegInfinity; - for (int e = 0; e < local_num; ++e) { - float s = load_elems[e]; - softmax_max = (softmax_max > s) ? softmax_max : s; - } - for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { - float other_val = sycl::permute_group_by_xor(sg, softmax_max, offset); - softmax_max = (softmax_max > other_val) ? softmax_max : other_val; - } - float softmax_sum = 0.0f; - for (int e = 0; e < local_num; ++e) { - float s = local_elems[e]; - softmax_sum += sycl::native::exp(s - softmax_max); - } - for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { - float other_val = sycl::permute_group_by_xor(sg, softmax_sum, offset); - softmax_sum += other_val; + if (!proceed) { + for (int i = laneIdx; i < topk; i += WARP_SIZE) { + topkIndices[i] = static_cast(i); + topkValues[i] = 1.0f / static_cast(topk); } - for (int e = 0; e < local_num; ++e) { - float s = local_elems[e]; - local_elems[e] = sycl::native::exp(s - softmax_max) / softmax_sum; - } - } else if (scoring_mode == ScoringFunc::SIGMOID) { - for (int e = 0; e < local_num; ++e) { - float s = load_elems[e]; - load_elems[e] = Sigmoid(s); - } - for (int e = 0; e < local_num; ++e) { - local_elems[e] = load_elems[e]; + return; + } + + constexpr int MaxExpertCandidatesPerLane = NumDeepseekExperts / WARP_SIZE; + T localCandidateScores[MaxExpertCandidatesPerLane]; + IdxT localCandidateIdx[MaxExpertCandidatesPerLane]; + T selectedExpertScores[DefaultMaxNumTopExperts]; + IdxT selectedExpertIdx[DefaultMaxNumTopExperts]; + + for (int i = 0; i < MaxExpertCandidatesPerLane; ++i) { + localCandidateScores[i] = neg_inf(); + localCandidateIdx[i] = 0; + } + + int32_t totalCandidates = topkGroup * numExpertsPerGroup; + for (int32_t candidate = laneIdx; candidate < totalCandidates; + candidate += WARP_SIZE) { + int32_t localSlot = candidate / WARP_SIZE; + int32_t selectedGroup = candidate / numExpertsPerGroup; + int32_t expertInGroup = candidate % numExpertsPerGroup; + int32_t gid = selectedGroupIdx[selectedGroup]; + int32_t idx = gid * numExpertsPerGroup + expertInGroup; + T candidateScore = neg_inf(); + + T input = scoresToken[idx]; + if (is_finite(input)) { + T score = apply_scoring(input); + candidateScore = score; + if (has_bias) { + candidateScore = + candidateScore + sycl_cast(routingBias[idx]); + } } + + localCandidateScores[localSlot] = candidateScore; + localCandidateIdx[localSlot] = static_cast(idx); } - bool has_bias = e_score_correction_bias != nullptr; - if (has_bias) { - for (int e = 0; e < local_num; ++e) { - bias[e] = e_score_correction_bias[start_offset + e]; + reduce_topk::reduceTopK( + subgroup, + selectedExpertScores, + selectedExpertIdx, + localCandidateScores, + localCandidateIdx, + neg_inf(), + static_cast(topk)); + + for (int i = 1; i < topk; ++i) { + T score = selectedExpertScores[i]; + IdxT idx = selectedExpertIdx[i]; + int j = i; + while (j > 0 && ((selectedExpertScores[j - 1] < score) || + ((selectedExpertScores[j - 1] == score) && + (selectedExpertIdx[j - 1] > idx)))) { + selectedExpertScores[j] = selectedExpertScores[j - 1]; + selectedExpertIdx[j] = selectedExpertIdx[j - 1]; + --j; } + selectedExpertScores[j] = score; + selectedExpertIdx[j] = idx; } - // perform topk_group groups - // 1 calculate each group scores - float group_scores[malloc_per_item * 2]; - for (int i = 0; i < num_expert_group * 2; ++i) { - group_scores[i] = kNegInfinity; + float laneUnbiased = 0.0f; + IdxT laneIdxOut = 0; + if (laneIdx < topk) { + laneIdxOut = selectedExpertIdx[laneIdx]; + T in = scoresToken[static_cast(laneIdxOut)]; + laneUnbiased = xpu::to_float(apply_scoring(in)); } - for (int i = 0; i < local_num; ++i) { - float b = bias[i]; - float score = local_elems[i] + b; - int i_group = (calc_per_item * sg_local_id + i) / experts_per_group; - float group_max = group_scores[i_group]; - float group_next_max = group_scores[num_expert_group + i_group]; - if (score > group_max) { - group_next_max = group_max; - group_max = score; - } else if (score > group_next_max) { - group_next_max = score; + + float scale = static_cast(routedScalingFactor); + if (renormalize) { + // Match baseline precision: sum and divide in T precision + T laneScoreT = static_cast(laneUnbiased); + T topkSumT = static_cast(0); + for (int i = 0; i < static_cast(topk); ++i) { + T val = sycl::select_from_group(subgroup, laneScoreT, i); + topkSumT = topkSumT + val; + } + if (laneIdx < topk) { + T normalizedT = laneScoreT / topkSumT; + laneUnbiased = xpu::to_float(normalizedT); } - group_scores[i_group] = group_max; - group_scores[num_expert_group + i_group] = group_next_max; - } - for (int i = 0; i < num_expert_group; ++i) { - float group_max = group_scores[i]; - float group_next_max = group_scores[num_expert_group + i]; - - float max1 = sycl::reduce_over_group( - sg, sycl::max(group_max, group_next_max), sycl::maximum<>()); - float local_second = - (group_max < max1 && group_max > -INFINITY) ? group_max : -INFINITY; - local_second = (group_next_max < max1 && group_next_max > local_second) - ? group_next_max - : local_second; - float max2 = sycl::reduce_over_group(sg, local_second, sycl::maximum<>()); - group_scores[i] = max1 + (has_bias ? max2 : 0.0f); } - // 2 find topk_group groups as kNegInfinity - int group_topk_idx[malloc_per_item]; - for (int k = 0; k < topk_group; ++k) { - float k_max = group_scores[0]; - int k_max_idx = 0; - for (int e = 1; e < num_expert_group; ++e) { - float score = group_scores[e]; - - if (score > k_max) { - k_max = score; - k_max_idx = e; + if (laneIdx < topk) { + topkIndices[laneIdx] = laneIdxOut; + topkValues[laneIdx] = laneUnbiased * scale; + } + return; + } else { + // Single-group path: select top-k from all experts using single warp. + constexpr int ExpertsPerLane = MaxNumExperts / WARP_SIZE; + auto subgroup = item.get_sub_group(); + T* scoresToken = scores + static_cast(blockIdx) * numExperts; + + if constexpr (MaxNumExperts <= NumKimiK2Experts) { + // Direct path: <=384 experts, fits in registers + T localBiasedScores[ExpertsPerLane]; + IdxT localIdx[ExpertsPerLane]; + + for (int i = 0; i < ExpertsPerLane; ++i) { + int expertId = i * WARP_SIZE + laneIdx; + if (expertId < numExperts) { + T raw = scoresToken[expertId]; + T scored = apply_scoring(raw); + localBiasedScores[i] = + has_bias ? (scored + sycl_cast(routingBias[expertId])) + : scored; + } else { + localBiasedScores[i] = neg_inf(); } + localIdx[i] = static_cast(expertId); } - group_scores[k_max_idx] = kNegInfinity; - group_topk_idx[k] = k_max_idx; - } - // 3 mask no-topk_group groups - for (int i = 0; i < calc_per_item; ++i) { - bool is_masked = true; - for (int k = 0; k < topk_group; ++k) { - if ((local_idx[i] / experts_per_group) == group_topk_idx[k]) { - is_masked = false; - break; + T selectedScores[MaxNumTopExperts]; + IdxT selectedIdx[MaxNumTopExperts]; + reduce_topk::reduceTopK( + subgroup, + selectedScores, + selectedIdx, + localBiasedScores, + localIdx, + neg_inf(), + static_cast(topk)); + + float unbiasedArr[DefaultMaxNumTopExperts]; + for (int i = 0; i < static_cast(topk); ++i) { + IdxT idx_i = selectedIdx[i]; + T raw = scoresToken[static_cast(idx_i)]; + T scored = apply_scoring(raw); + unbiasedArr[i] = xpu::to_float(scored); + } + + float laneUnbiased = 0.0f; + IdxT laneIdxOut = 0; + if (laneIdx < topk) { + laneIdxOut = selectedIdx[laneIdx]; + laneUnbiased = unbiasedArr[laneIdx]; + } + + float scale = static_cast(routedScalingFactor); + if (renormalize) { + float topk_sum = 0.0f; + for (int i = 0; i < static_cast(topk); ++i) { + topk_sum += unbiasedArr[i]; + } + if (laneIdx < topk) { + laneUnbiased /= topk_sum; } } - if (is_masked) { - local_elems[i] = kNegInfinity; + + if (laneIdx < topk) { + topkValues[laneIdx] = laneUnbiased * scale; + topkIndices[laneIdx] = laneIdxOut; + } + } else { + // Large expert path (>256): Each lane streams through its experts, + // maintaining a local top-k via insertion sort, then iteratively + // selects global top-k across the warp. Minimal register pressure. + constexpr int LocalTopK = DefaultMaxNumTopExperts; // 8 + auto subgroup = item.get_sub_group(); + T* scoresToken2 = scores + static_cast(blockIdx) * numExperts; + + // Each lane maintains its local top-k sorted descending + float laneTopVal[LocalTopK]; + IdxT laneTopIdx[LocalTopK]; + for (int k = 0; k < LocalTopK; ++k) { + laneTopVal[k] = -std::numeric_limits::infinity(); + laneTopIdx[k] = std::numeric_limits::max(); } - } - // Perform top-k selection - T topk_weights_local[malloc_per_item]; - int topk_ids_local[malloc_per_item]; - - for (int k = 0; k < top_k; ++k) { - float k_max = kNegInfinity; - int k_max_idx = -1; - int remove_ix = -1; - for (int e = 0; e < calc_per_item; ++e) { - float le = local_elems[e]; - float b = bias[e]; - float my_val = le + b; - int my_idx = local_idx[e]; - for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { - float other_val = sycl::permute_group_by_xor(sg, my_val, offset); - int other_idx = sycl::permute_group_by_xor(sg, my_idx, offset); - if (other_val > my_val || - (other_val == my_val && other_idx < my_idx)) { - my_val = other_val; - my_idx = other_idx; + // Stream through all experts assigned to this lane + for (int i = 0; i < ExpertsPerLane; ++i) { + int expertId = i * WARP_SIZE + laneIdx; + if (expertId >= numExperts) break; + + T raw = scoresToken2[expertId]; + T scored = apply_scoring(raw); + float val = + has_bias ? xpu::to_float( + scored + sycl_cast(routingBias[expertId])) + : xpu::to_float(scored); + IdxT idx = static_cast(expertId); + + // Check if this beats the worst in our local top-k + if (val > laneTopVal[LocalTopK - 1] || + (val == laneTopVal[LocalTopK - 1] && + idx < laneTopIdx[LocalTopK - 1])) { + // Insert into sorted position + int pos = LocalTopK - 1; + for (int j = LocalTopK - 2; j >= 0; --j) { + if (val > laneTopVal[j] || + (val == laneTopVal[j] && idx < laneTopIdx[j])) { + laneTopVal[j + 1] = laneTopVal[j]; + laneTopIdx[j + 1] = laneTopIdx[j]; + pos = j; + } else { + break; + } } - } - if (my_val > k_max || (my_val == k_max && my_idx < k_max_idx)) { - k_max = my_val; - k_max_idx = my_idx; - - if (k_max_idx == local_idx[e]) { - remove_ix = e; // Mark this index for removal - } else - remove_ix = -1; + laneTopVal[pos] = val; + laneTopIdx[pos] = idx; } } - int select_item = k_max_idx / calc_per_item; - int select_elem = k_max_idx % calc_per_item; - k_max = local_elems[select_elem]; - k_max = sycl::group_broadcast(sg, k_max, select_item); - if (remove_ix != -1) { - local_elems[remove_ix] = - kNegInfinity; // Reset the score to avoid re-selection - local_idx[remove_ix] = -1; - remove_ix = -1; + // Now iteratively select global top-k from the warp. + // Each lane offers its current best (top of local sorted list). + // The warp finds the global max, the winning lane pops it. + IdxT globalTopIdx[DefaultMaxNumTopExperts]; + int lanePtr = 0; // points to next candidate in laneTopVal + + for (int k = 0; k < static_cast(topk); ++k) { + float myVal = (lanePtr < LocalTopK) + ? laneTopVal[lanePtr] + : -std::numeric_limits::infinity(); + IdxT myIdx = (lanePtr < LocalTopK) ? laneTopIdx[lanePtr] + : std::numeric_limits::max(); + + // Find the best value across all lanes + float bestVal = + sycl::reduce_over_group(subgroup, myVal, sycl::maximum()); + + // Among lanes that have bestVal, pick smallest idx + IdxT candidateIdx = + (myVal == bestVal) ? myIdx : std::numeric_limits::max(); + IdxT bestIdx = sycl::reduce_over_group( + subgroup, candidateIdx, sycl::minimum()); + + globalTopIdx[k] = bestIdx; + + // The winning lane advances its pointer + if (myIdx == bestIdx && myVal == bestVal) { + lanePtr++; + } } - topk_weights_local[k] = k_max; - topk_ids_local[k] = k_max_idx < 0 ? k : k_max_idx; - } + // Re-read unbiased scores for winners + float unbiasedArr[DefaultMaxNumTopExperts]; + for (int i = 0; i < static_cast(topk); ++i) { + T raw = scoresToken2[static_cast(globalTopIdx[i])]; + T scored = apply_scoring(raw); + unbiasedArr[i] = xpu::to_float(scored); + } - if (renormalize) { - // Renormalize the top-k weights - float sum = 0; - for (int i = 0; i < top_k; ++i) { - sum += topk_weights_local[i]; + float laneUnbiased = 0.0f; + IdxT laneIdxOut = 0; + if (laneIdx < topk) { + laneIdxOut = globalTopIdx[laneIdx]; + laneUnbiased = unbiasedArr[laneIdx]; } - if (sum > 0) { - for (int i = 0; i < top_k; ++i) { - topk_weights_local[i] /= sum; + + float scale = static_cast(routedScalingFactor); + if (renormalize) { + float topk_sum = 0.0f; + for (int i = 0; i < static_cast(topk); ++i) { + topk_sum += unbiasedArr[i]; + } + if (laneIdx < topk) { + laneUnbiased /= topk_sum; } } - } - if (sg_local_id == 0) { - int offset = tid * top_k; - for (int i = 0; i < top_k; ++i) { - topk_weights[offset + i] = - topk_weights_local[i] * routed_scaling_factor; - if (!(topk_ids_local[i] >= 0 && topk_ids_local[i] < experts)) { - // Ensure valid index - topk_ids[offset + i] = 0; - continue; - } - topk_ids[offset + i] = topk_ids_local[i]; + if (laneIdx < topk) { + topkValues[laneIdx] = laneUnbiased * scale; + topkIndices[laneIdx] = laneIdxOut; } } - } - float* topk_weights; - int* topk_ids; - const T* gating_output; - const T* e_score_correction_bias; - const double routed_scaling_factor; - const ScoringFunc scoring_mode; - const bool renormalize; - const int tokens; - const int experts; - const int top_k; - const int num_expert_group; - const int topk_group; -}; - -template -void launch_fused_grouped_topk( - sycl::queue& queue, - float* topk_weights, - int* topk_ids, - const T* gating_output, - const T* e_score_correction_bias, - const double routed_scaling_factor, - const ScoringFunc scoring_mode, - const bool renormalize, - const int tokens, - const int experts, - const int top_k, - const int num_expert_group, - const int topk_group) { - using Kernel = Fused_Grouped_Topk; - auto range = Kernel::get_nd_range(tokens, experts); - - queue.submit([&](sycl::handler& cgh) { - Kernel task( - topk_weights, - topk_ids, - gating_output, - e_score_correction_bias, - routed_scaling_factor, - scoring_mode, - renormalize, - tokens, - experts, - top_k, - num_expert_group, - topk_group); - cgh.parallel_for(range, task); - }); + } // end if constexpr (!MultiGroups) } -template -void fused_grouped_topk( - float* topk_weights, - int* topk_ids, - const T* gating_output, - const T* e_score_correction_bias, - const double routed_scaling_factor, - const ScoringFunc scoring_mode, - const bool renormalize, - const int tokens, - const int experts, - const int top_k, - const int num_expert_group, - const int topk_group) { - auto& queue = vllm::xpu::vllmGetQueue(); - - TORCH_CHECK( - topk_group <= num_expert_group, - "topk_group must be less than or equal to num_expert_group"); - TORCH_CHECK( - experts % num_expert_group == 0, - "The number of experts (experts=", - experts, - ") must be divisible by num_expert_group (", - num_expert_group, - ")."); - - int max_expert_group = ((num_expert_group + 7) / 8) * 8; -#define CASE_TOPK(K) \ - case K: \ - launch_fused_grouped_topk( \ - queue, \ - topk_weights, \ - topk_ids, \ - gating_output, \ - e_score_correction_bias, \ - routed_scaling_factor, \ - scoring_mode, \ - renormalize, \ - tokens, \ - experts, \ - top_k, \ - num_expert_group, \ - topk_group); \ - break; - switch (max_expert_group) { - CASE_TOPK(8) - CASE_TOPK(16) - default: - TORCH_CHECK( - false, "error: not support num_expert_group=%d,\n", num_expert_group); +template +void invokeNoAuxTc( + T* scores, + float* topk_values, + IdxT* topk_indices, + BiasT const* bias, + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double const routed_scaling_factor, + bool enable_pdl, + sycl::queue& queue) { + int64_t experts_per_group = num_experts / n_group; + bool is_single_group = + (n_group == 1) && (topk_group == 1) && + (num_experts <= MaxSupportedExpertCount) && + (topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts); + +#define LAUNCH_SMALL_KERNEL( \ + MAX_EXPERTS, MultiGroups, MAX_TOP_EXPERTS, NUM_THREADS) \ + do { \ + size_t local_size = static_cast(NUM_THREADS); \ + size_t global_size = static_cast(num_tokens) * local_size; \ + queue.submit([&](sycl::handler& cgh) { \ + cgh.parallel_for>( \ + sycl::nd_range<1>( \ + sycl::range<1>(global_size), sycl::range<1>(local_size)), \ + [=](sycl::nd_item<1> item) \ + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { \ + grouped_topk_fused_small_expert_count_kernel< \ + T, \ + BiasT, \ + IdxT, \ + SF, \ + MAX_EXPERTS, \ + MultiGroups, \ + MAX_TOP_EXPERTS>( \ + scores, \ + topk_values, \ + topk_indices, \ + bias, \ + num_tokens, \ + n_group, \ + topk_group, \ + topk, \ + num_experts, \ + experts_per_group, \ + renormalize, \ + routed_scaling_factor, \ + item); \ + }); \ + }); \ + } while (0) + + if (is_single_group) { + if (num_experts == NumNemotronExperts && n_group == 1 && + topk == MaxSupportedTopExperts) { + LAUNCH_SMALL_KERNEL( + NumNemotronExperts, false, MaxSupportedTopExperts, WARP_SIZE); + } else if ( + num_experts > NumKimiK2Experts && + num_experts <= MaxSupportedExpertCount) { + LAUNCH_SMALL_KERNEL( + MaxSupportedExpertCount, false, DefaultMaxNumTopExperts, WARP_SIZE); + } else if ( + num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts) { + LAUNCH_SMALL_KERNEL( + NumKimiK2Experts, false, DefaultMaxNumTopExperts, WARP_SIZE); + } else { + LAUNCH_SMALL_KERNEL( + MaxNumExpertsUnit, false, DefaultMaxNumTopExperts, WARP_SIZE); + } + } else { + LAUNCH_SMALL_KERNEL( + NumDeepseekExperts, true, DefaultMaxNumTopExperts, WARP_SIZE); } -#undef CASE_TOPK + +#undef LAUNCH_SMALL_KERNEL } -}; // namespace GroupedTopKImpl +#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \ + template void invokeNoAuxTc( \ + T * scores, \ + float* topk_values, \ + IdxT* topk_indices, \ + BiasT const* bias, \ + int64_t const num_tokens, \ + int64_t const num_experts, \ + int64_t const n_group, \ + int64_t const topk_group, \ + int64_t const topk, \ + bool const renormalize, \ + double const routed_scaling_factor, \ + bool enable_pdl, \ + sycl::queue& queue); + +INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, + sycl::ext::oneapi::bfloat16, + int32_t, + SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(float, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::half, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::half, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC( + sycl::half, sycl::ext::oneapi::bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(sycl::ext::oneapi::bfloat16, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, sycl::half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC( + sycl::ext::oneapi::bfloat16, + sycl::ext::oneapi::bfloat16, + int32_t, + SCORING_NONE); +} // end namespace moe } // namespace vllm -/** - * @brief Perform grouped topk after sigmoid/addbias on gating_output. - * @param gating_output The gating output tensor of shape [n_tokens, n_experts]. - * @param n_topk The number of top experts to select. - * @param n_topk_group The number of top experts to select in the group. - * @return A tuple of tensors (topk_weights, topk_indices). - */ std::tuple fused_grouped_topk( - const torch::Tensor& hidden_states, - const torch::Tensor& gating_output, - const int64_t n_topk, - const bool renormalize, - const int64_t n_expert_group, - const int64_t n_topk_group, - const c10::string_view scoring_func, - const double routed_scaling_factor, - const c10::optional& bias) { - auto shape = gating_output.sizes().vec(); + torch::Tensor const& hidden_states, + torch::Tensor const& gating_output, + int64_t const n_topk, + bool const renormalize, + int64_t const n_expert_group, + int64_t const n_topk_group, + c10::string_view const scoring_func, + double const routed_scaling_factor, + c10::optional const& bias) { + auto data_type = gating_output.scalar_type(); + bool has_bias = bias.has_value() && bias->defined(); + auto bias_type = has_bias ? bias->scalar_type() : torch::kFloat32; + auto input_size = gating_output.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + int64_t n_group = n_expert_group; + int64_t topk_group = n_topk_group; + int64_t topk = n_topk; + TORCH_CHECK( hidden_states.sizes()[0] == gating_output.sizes()[0], - "Number of tokens mismatch") + "Number of tokens mismatch"); + TORCH_CHECK(input_size.size() == 2, "gating_output must be a 2D Tensor"); + TORCH_CHECK(n_group > 0, "n_group must be positive"); + TORCH_CHECK(topk > 0, "topk must be positive"); + TORCH_CHECK(topk_group > 0, "topk_group must be positive"); + TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group"); TORCH_CHECK( - shape.size() == 2, - "gating_output must be 2D tensor, but got ", - shape.size(), - "D"); - if (bias.has_value()) { - auto shape_bias = bias->sizes().vec(); - TORCH_CHECK( - shape_bias[0] == shape[1], - "gating_output and bias must has same innermost dimension, but got ", - shape, - " and ", - shape_bias); - } - int n_tokens = shape[0]; - int n_experts = shape[1]; - - vllm::GroupedTopKImpl::ScoringFunc scoring_mode; - if (scoring_func == "sigmoid") { - scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::SIGMOID; - } else if (scoring_func == "softmax") { - scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::SOFTMAX; - } else { - scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::DEFAULT; - } - - auto topk_weights = - torch::empty({n_tokens, n_topk}, at::dtype(at::kFloat).device(at::kXPU)); - auto topk_indices = - torch::empty({n_tokens, n_topk}, at::dtype(at::kInt).device(at::kXPU)); - - if (gating_output.scalar_type() == at::kBFloat16) { - using scalar_t = sycl::ext::oneapi::bfloat16; - vllm::GroupedTopKImpl::fused_grouped_topk( - reinterpret_cast(topk_weights.data_ptr()), - reinterpret_cast(topk_indices.data_ptr()), - reinterpret_cast(gating_output.data_ptr()), - bias.has_value() ? reinterpret_cast(bias->data_ptr()) - : nullptr, - routed_scaling_factor, - scoring_mode, - renormalize, - n_tokens, - n_experts, - n_topk, - n_expert_group, - n_topk_group); - } else if (gating_output.scalar_type() == at::kHalf) { - using scalar_t = sycl::half; - vllm::GroupedTopKImpl::fused_grouped_topk( - reinterpret_cast(topk_weights.data_ptr()), - reinterpret_cast(topk_indices.data_ptr()), - reinterpret_cast(gating_output.data_ptr()), - bias.has_value() ? reinterpret_cast(bias->data_ptr()) - : nullptr, - routed_scaling_factor, - scoring_mode, - renormalize, - n_tokens, - n_experts, - n_topk, - n_expert_group, - n_topk_group); - } else { - using scalar_t = float; - vllm::GroupedTopKImpl::fused_grouped_topk( - reinterpret_cast(topk_weights.data_ptr()), - reinterpret_cast(topk_indices.data_ptr()), - reinterpret_cast(gating_output.data_ptr()), - bias.has_value() ? reinterpret_cast(bias->data_ptr()) - : nullptr, - routed_scaling_factor, - scoring_mode, - renormalize, - n_tokens, - n_experts, - n_topk, - n_expert_group, - n_topk_group); + num_experts % n_group == 0, "num_experts should be divisible by n_group"); + TORCH_CHECK( + n_group <= 32, "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + TORCH_CHECK( + topk <= topk_group * (num_experts / n_group), + "topk must be <= topk_group * (num_experts / n_group)"); + TORCH_CHECK( + scoring_func == "sigmoid" || scoring_func == "softmax", + "Unsupported scoring_func: ", + scoring_func); + auto const sf = (scoring_func == "sigmoid") ? vllm::moe::SCORING_SIGMOID + : vllm::moe::SCORING_NONE; + + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, + torch::dtype(torch::kFloat32).device(gating_output.device())); + torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, + torch::dtype(torch::kInt32).device(gating_output.device())); + + // auto device_idx = gating_output.device().index(); + auto device = gating_output.device(); + auto& stream = vllm::xpu::vllmGetQueue(device.index()); + +#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \ + do { \ + switch (sf) { \ + case vllm::moe::SCORING_NONE: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(gating_output.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + (has_bias ? reinterpret_cast(bias->data_ptr()) \ + : nullptr), \ + num_tokens, \ + num_experts, \ + n_group, \ + topk_group, \ + topk, \ + renormalize, \ + routed_scaling_factor, \ + false, \ + stream); \ + break; \ + case vllm::moe::SCORING_SIGMOID: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(gating_output.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + (has_bias ? reinterpret_cast(bias->data_ptr()) \ + : nullptr), \ + num_tokens, \ + num_experts, \ + n_group, \ + topk_group, \ + topk, \ + renormalize, \ + routed_scaling_factor, \ + false, \ + stream); \ + break; \ + default: \ + throw std::invalid_argument("Unsupported scoring_func"); \ + break; \ + } \ + } while (0) + +#define LAUNCH_KERNEL(T, IdxT) \ + do { \ + switch (bias_type) { \ + case torch::kFloat16: \ + LAUNCH_KERNEL_SF(T, sycl::half, IdxT); \ + break; \ + case torch::kFloat32: \ + LAUNCH_KERNEL_SF(T, float, IdxT); \ + break; \ + case torch::kBFloat16: \ + LAUNCH_KERNEL_SF(T, sycl::ext::oneapi::bfloat16, IdxT); \ + break; \ + default: \ + throw std::invalid_argument( \ + "Invalid bias dtype, only supports float16, float32, and " \ + "bfloat16"); \ + break; \ + } \ + } while (0) + + switch (data_type) { + case torch::kFloat16: + LAUNCH_KERNEL(sycl::half, int32_t); + break; + case torch::kFloat32: + LAUNCH_KERNEL(float, int32_t); + break; + case torch::kBFloat16: + LAUNCH_KERNEL(sycl::ext::oneapi::bfloat16, int32_t); + break; + default: + throw std::invalid_argument( + "Invalid dtype, only supports float16, float32, and bfloat16"); + break; } - return std::make_tuple(topk_weights, topk_indices); +#undef LAUNCH_KERNEL +#undef LAUNCH_KERNEL_SF + return {topk_values, topk_indices}; } \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 47eb1d61d..fa3f4e07e 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -56,12 +56,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Fused Grouped TopK m.def( - "fused_grouped_topk(Tensor hidden_states, Tensor gating_output, int " - "n_topk, " - "bool renormalize, int n_expert_group, int n_topk_group, str " - "scoring_func, float routed_scaling_factor, Tensor? bias=None) -> " - "(Tensor, Tensor)"); + "fused_grouped_topk(" + " Tensor hidden_states," + " Tensor gating_output," + " int n_topk," + " bool renormalize," + " int n_expert_group," + " int n_topk_group," + " str scoring_func," + " float routed_scaling_factor," + " Tensor? bias=None" + ") -> (Tensor, Tensor)"); m.impl("fused_grouped_topk", torch::kXPU, &fused_grouped_topk); + // Apply topk softmax to the gating outputs. m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " diff --git a/tests/ops/grouped_topk_op.py b/tests/ops/grouped_topk_op.py index 40fca3a98..71c698c5f 100644 --- a/tests/ops/grouped_topk_op.py +++ b/tests/ops/grouped_topk_op.py @@ -28,6 +28,7 @@ def grouped_topk( scores = gating_output.sigmoid() else: raise ValueError(f"Unsupported scoring function: {scoring_func}") + num_token = scores.size(0) if e_score_correction_bias is not None: # Store original scores before applying correction bias. We use biased @@ -37,30 +38,38 @@ def grouped_topk( group_scores = (scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)) else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, - sorted=False)[1] # [n, top_k_group] + group_scores = (scores.view(num_token, num_expert_group, + -1).max(dim=-1).values) # [n, n_group] + # For batch invariance, use sorted=True to ensure + # deterministic expert selection + use_sorted = True + group_idx = torch.topk(group_scores, + k=topk_group, + dim=-1, + sorted=use_sorted)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( + score_mask = (group_mask.unsqueeze(-1).expand( num_token, num_expert_group, - scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] + scores.size(-1) // num_expert_group).reshape(num_token, -1)) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, - sorted=False) + sorted=use_sorted) + if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_weights = topk_weights * routed_scaling_factor + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) @@ -87,8 +96,14 @@ def fused_grouped_topk( scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) topk_values, topk_indices = ops.grouped_topk( - scores, scores_with_bias.to(scores.dtype), num_expert_group, - topk_group, topk, renormalize, routed_scaling_factor) + scores, + scores_with_bias.to(scores.dtype), + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) return topk_values.to(torch.float32), topk_indices.to(torch.int32) @@ -103,7 +118,22 @@ def fused_grouped_topk_sycl( routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - return ops.fused_grouped_topk(hidden_states, gating_output, topk, - renormalize, num_expert_group, topk_group, - scoring_func, routed_scaling_factor, - e_score_correction_bias) + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + return ops.fused_grouped_topk( + hidden_states, + scores, + topk, + renormalize, + num_expert_group, + topk_group, + scoring_func, + routed_scaling_factor, + e_score_correction_bias, + ) diff --git a/tests/register_ops.py b/tests/register_ops.py index 0508ed06d..136ae8e22 100644 --- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -458,7 +458,7 @@ def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor, def fused_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, - topk: int, + n_topk: int, renormalize: bool, num_expert_group: int, topk_group: int, @@ -467,7 +467,7 @@ def fused_grouped_topk( e_score_correction_bias: Optional[torch.Tensor] = None, ): return torch.ops._moe_C.fused_grouped_topk(hidden_states, gating_output, - topk, renormalize, + n_topk, renormalize, num_expert_group, topk_group, scoring_func, routed_scaling_factor, diff --git a/tests/test_grouped_topk.py b/tests/test_grouped_topk.py index ec383a8a0..82968dcb9 100644 --- a/tests/test_grouped_topk.py +++ b/tests/test_grouped_topk.py @@ -90,4 +90,4 @@ def test_grouped_topk(n_token: int, n_hidden: int, n_expert: int, topk: int, torch.testing.assert_close(baseline_topk_ids, test_topk_ids_sycl, atol=0, - rtol=0) + rtol=0) \ No newline at end of file