@@ -58,6 +58,27 @@ inline float sigmoid_accurate(float x) {
5858 return 1 .f / (1 .f + sycl::native::exp (-x));
5959}
6060
61+
62+ template <typename T>
63+ inline T warp_reduce_max (sycl::sub_group sg, T val) {
64+ for (int offset = WARP_SIZE / 2 ; offset > 0 ; offset >>= 1 ) {
65+ T other = sycl::select_from_group (sg, val,
66+ (sg.get_local_linear_id () ^ offset));
67+ val = sycl::max (val, other);
68+ }
69+ return val;
70+ }
71+
72+ template <typename T>
73+ inline T warp_reduce_min (sycl::sub_group sg, T val) {
74+ for (int offset = WARP_SIZE / 2 ; offset > 0 ; offset >>= 1 ) {
75+ T other = sycl::select_from_group (sg, val,
76+ (sg.get_local_linear_id () ^ offset));
77+ val = sycl::min (val, other);
78+ }
79+ return val;
80+ }
81+
6182template <typename T>
6283inline T apply_sigmoid (T val) {
6384 float f = xpu::to_float (val);
@@ -109,17 +130,15 @@ inline void reduceTopK(
109130 }
110131 }
111132 float local_best_val_tmp = xpu::to_float (local_best_val);
112- float warp_best_val_tmp = sycl::reduce_over_group (
113- subgroup, local_best_val_tmp, sycl::maximum<float >());
133+ float warp_best_val_tmp = warp_reduce_max (subgroup, local_best_val_tmp);
114134
115135 T warp_best_val = static_cast <T>(warp_best_val_tmp);
116136 IdxT warp_best_idx = invalid_idx;
117137
118138 if (local_best_pos != -1 && local_best_val == warp_best_val) {
119139 warp_best_idx = local_best_idx;
120140 }
121- warp_best_idx =
122- sycl::reduce_over_group (subgroup, warp_best_idx, sycl::minimum<IdxT>());
141+ warp_best_idx = warp_reduce_min (subgroup, warp_best_idx);
123142
124143 bool found = (warp_best_idx != invalid_idx);
125144 if (found) {
@@ -473,14 +492,12 @@ SYCL_EXTERNAL inline void grouped_topk_fused_small_expert_count_kernel(
473492 : std::numeric_limits<IdxT>::max ();
474493
475494 // Find the best value across all lanes
476- float bestVal =
477- sycl::reduce_over_group (subgroup, myVal, sycl::maximum<float >());
495+ float bestVal = warp_reduce_max (subgroup, myVal);
478496
479497 // Among lanes that have bestVal, pick smallest idx
480498 IdxT candidateIdx =
481499 (myVal == bestVal) ? myIdx : std::numeric_limits<IdxT>::max ();
482- IdxT bestIdx = sycl::reduce_over_group (
483- subgroup, candidateIdx, sycl::minimum<IdxT>());
500+ IdxT bestIdx = warp_reduce_min (subgroup, candidateIdx);
484501
485502 globalTopIdx[k] = bestIdx;
486503
@@ -537,8 +554,8 @@ void invokeNoAuxTc(
537554 int64_t const topk,
538555 bool const renormalize,
539556 double const routed_scaling_factor,
540- bool enable_pdl = false ,
541- sycl::queue queue = sycl::queue() ) {
557+ bool enable_pdl,
558+ sycl::queue& queue) {
542559 int64_t experts_per_group = num_experts / n_group;
543560 bool is_single_group =
544561 (n_group == 1 ) && (topk_group == 1 ) &&
@@ -628,7 +645,7 @@ void invokeNoAuxTc(
628645 bool const renormalize, \
629646 double const routed_scaling_factor, \
630647 bool enable_pdl, \
631- sycl::queue queue);
648+ sycl::queue& queue);
632649
633650INSTANTIATE_NOAUX_TC (float , float , int32_t , SCORING_SIGMOID);
634651INSTANTIATE_NOAUX_TC (float , sycl::half, int32_t , SCORING_SIGMOID);
@@ -715,8 +732,9 @@ std::tuple<torch::Tensor, torch::Tensor> fused_grouped_topk(
715732 {num_tokens, topk},
716733 torch::dtype (torch::kInt32 ).device (gating_output.device ()));
717734
718- auto device_idx = gating_output.device ().index ();
719- auto stream = c10::xpu::getCurrentXPUStream (device_idx).queue ();
735+ // auto device_idx = gating_output.device().index();
736+ auto device = gating_output.device ();
737+ auto & stream = vllm::xpu::vllmGetQueue (device.index ());
720738
721739#define LAUNCH_KERNEL_SF (T, BiasT, IdxT ) \
722740 do { \
0 commit comments