Skip to content

Commit 31779fa

Browse files
change kernel launch method
Signed-off-by: root <xiaolong.guo@intel.com>
1 parent c88c871 commit 31779fa

1 file changed

Lines changed: 31 additions & 13 deletions

File tree

csrc/moe/fused_grouped_topk.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,27 @@ inline float sigmoid_accurate(float x) {
5858
return 1.f / (1.f + sycl::native::exp(-x));
5959
}
6060

61+
62+
template <typename T>
63+
inline T warp_reduce_max(sycl::sub_group sg, T val) {
64+
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
65+
T other = sycl::select_from_group(sg, val,
66+
(sg.get_local_linear_id() ^ offset));
67+
val = sycl::max(val, other);
68+
}
69+
return val;
70+
}
71+
72+
template <typename T>
73+
inline T warp_reduce_min(sycl::sub_group sg, T val) {
74+
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
75+
T other = sycl::select_from_group(sg, val,
76+
(sg.get_local_linear_id() ^ offset));
77+
val = sycl::min(val, other);
78+
}
79+
return val;
80+
}
81+
6182
template <typename T>
6283
inline T apply_sigmoid(T val) {
6384
float f = xpu::to_float(val);
@@ -109,17 +130,15 @@ inline void reduceTopK(
109130
}
110131
}
111132
float local_best_val_tmp = xpu::to_float(local_best_val);
112-
float warp_best_val_tmp = sycl::reduce_over_group(
113-
subgroup, local_best_val_tmp, sycl::maximum<float>());
133+
float warp_best_val_tmp = warp_reduce_max(subgroup, local_best_val_tmp);
114134

115135
T warp_best_val = static_cast<T>(warp_best_val_tmp);
116136
IdxT warp_best_idx = invalid_idx;
117137

118138
if (local_best_pos != -1 && local_best_val == warp_best_val) {
119139
warp_best_idx = local_best_idx;
120140
}
121-
warp_best_idx =
122-
sycl::reduce_over_group(subgroup, warp_best_idx, sycl::minimum<IdxT>());
141+
warp_best_idx = warp_reduce_min(subgroup, warp_best_idx);
123142

124143
bool found = (warp_best_idx != invalid_idx);
125144
if (found) {
@@ -473,14 +492,12 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel(
473492
: std::numeric_limits<IdxT>::max();
474493

475494
// Find the best value across all lanes
476-
float bestVal =
477-
sycl::reduce_over_group(subgroup, myVal, sycl::maximum<float>());
495+
float bestVal = warp_reduce_max(subgroup, myVal);
478496

479497
// Among lanes that have bestVal, pick smallest idx
480498
IdxT candidateIdx =
481499
(myVal == bestVal) ? myIdx : std::numeric_limits<IdxT>::max();
482-
IdxT bestIdx = sycl::reduce_over_group(
483-
subgroup, candidateIdx, sycl::minimum<IdxT>());
500+
IdxT bestIdx = warp_reduce_min(subgroup, candidateIdx);
484501

485502
globalTopIdx[k] = bestIdx;
486503

@@ -537,8 +554,8 @@ void invokeNoAuxTc(
537554
int64_t const topk,
538555
bool const renormalize,
539556
double const routed_scaling_factor,
540-
bool enable_pdl = false,
541-
sycl::queue queue = sycl::queue()) {
557+
bool enable_pdl,
558+
sycl::queue& queue) {
542559
int64_t experts_per_group = num_experts / n_group;
543560
bool is_single_group =
544561
(n_group == 1) && (topk_group == 1) &&
@@ -628,7 +645,7 @@ void invokeNoAuxTc(
628645
bool const renormalize, \
629646
double const routed_scaling_factor, \
630647
bool enable_pdl, \
631-
sycl::queue queue);
648+
sycl::queue& queue);
632649

633650
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID);
634651
INSTANTIATE_NOAUX_TC(float, sycl::half, int32_t, SCORING_SIGMOID);
@@ -715,8 +732,9 @@ std::tuple<torch::Tensor, torch::Tensor> fused_grouped_topk(
715732
{num_tokens, topk},
716733
torch::dtype(torch::kInt32).device(gating_output.device()));
717734

718-
auto device_idx = gating_output.device().index();
719-
auto stream = c10::xpu::getCurrentXPUStream(device_idx).queue();
735+
// auto device_idx = gating_output.device().index();
736+
auto device = gating_output.device();
737+
auto& stream = vllm::xpu::vllmGetQueue(device.index());
720738

721739
#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \
722740
do { \

0 commit comments

Comments
 (0)