SortingKernels: use int64_t type for num_tile #3555
To avoid num_tile overflow, it should be declared as int64_t.
Pull request overview
Updates the XPU SYCL sorting kernel to compute num_tiles in 64-bit to avoid overflow during the ceil-division calculation for large num_elements.
Changes:
- Switches `num_tiles` in `segmented_radix_sort_pairs_kernel` from `int` to `int64_t`.
- Performs the `num_elements + TILE_PROCESSING_LENGTH - 1` arithmetic in 64-bit via `static_cast<int64_t>(num_elements)`.
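The ceil-division described above can be sketched as follows. This is a minimal illustration, not the code from the PR: `kTileProcessingLength` and `num_tiles_64` are hypothetical names, and the value 4096 is an assumed stand-in for the real `TILE_PROCESSING_LENGTH` defined in `SortingKernels.h`.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical stand-in for TILE_PROCESSING_LENGTH; the real value lives in
// SortingKernels.h and may differ.
constexpr int kTileProcessingLength = 4096;

// Ceil-division carried out in 64-bit. Casting num_elements first means the
// "+ kTileProcessingLength - 1" step cannot wrap around, even when
// num_elements is close to the maximum value of int.
constexpr int64_t num_tiles_64(int num_elements) {
  return (static_cast<int64_t>(num_elements) + kTileProcessingLength - 1) /
      kTileProcessingLength;
}

// Near the int limit, the intermediate sum exceeds what a 32-bit int can
// hold, yet the 64-bit result is still the correct number of tiles.
static_assert(
    num_tiles_64(std::numeric_limits<int>::max()) == 524288,
    "ceil(INT_MAX / 4096) should be 2^19 tiles");
```

Doing the same sum in plain `int` would be signed overflow for inputs near `INT_MAX`, which is undefined behavior in C++, so the cast has to happen before the addition rather than after the division.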
Skill files read:
- .github/skills/xpu-ops-pr-review/SKILL.md
- .github/skills/xpu-ops-pr-review/references/torch-xpu-ops-review-notes.md
- .github/skills/xpu-ops-pr-review/references/review-checklist.md
- .github/skills/xpu-ops-pr-review/references/bc-guidelines.md
Pull request overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated no new comments.
Comments suppressed due to low confidence (2)
src/ATen/native/xpu/sycl/SortingKernels.h:194
`sycl_kernel_submit(num_segments * num_tiles * GROUP_SIZE, ...)` multiplies `int` values first, so the product can overflow in 32-bit signed arithmetic before being passed to the `int64_t` overload. Please promote to `int64_t` (and ideally check for overflow) when computing the global range.
This issue also appears on line 328 of the same file.
auto caller = SegmentedRadixSortPairsUpsweepFunctor<method_t, key_t, value_t>(
    keys_in, counts, num_elements, begin_bit, end_bit);
sycl_kernel_submit(
    num_segments * num_tiles * GROUP_SIZE,
    GROUP_SIZE,
    at::xpu::getCurrentSYCLQueue(),
    caller);
src/ATen/native/xpu/sycl/SortingKernels.h:342
`sycl_kernel_submit(num_segments * num_tiles * GROUP_SIZE, ...)` is still computed in `int` arithmetic here, which can overflow before conversion to the `int64_t` parameter type. Compute the global range using `int64_t` (and consider an overflow check) to avoid launching an incorrect ND-range.
auto caller =
    SegmentedRadixSortPairsDownsweepFunctor<method_t, key_t, value_t>(
        keys_in,
        keys_out,
        values_in,
        values_out,
        num_elements,
        begin_bit,
        end_bit,
        count);
sycl_kernel_submit(
    num_segments * num_tiles * GROUP_SIZE,
    GROUP_SIZE,
    at::xpu::getCurrentSYCLQueue(),
    caller);
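One way to act on the two suppressed comments is to compute the global range in 64-bit and reject products that do not fit. The sketch below is illustrative only: `checked_global_range` is a hypothetical helper that does not exist in torch-xpu-ops, and it assumes C++17 for `std::optional`.

```cpp
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical helper (not part of torch-xpu-ops): returns
// num_segments * num_tiles * group_size as int64_t, or std::nullopt when the
// inputs are invalid or the product would not fit in int64_t.
inline std::optional<int64_t> checked_global_range(
    int64_t num_segments,
    int64_t num_tiles,
    int64_t group_size) {
  if (num_segments < 0 || num_tiles < 0 || group_size <= 0) {
    return std::nullopt;
  }
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  // Check each multiplication by division before performing it, so the
  // check itself never overflows.
  if (num_tiles != 0 && num_segments > kMax / num_tiles) {
    return std::nullopt;
  }
  const int64_t segments_times_tiles = num_segments * num_tiles;
  if (segments_times_tiles > kMax / group_size) {
    return std::nullopt;
  }
  return segments_times_tiles * group_size;
}
```

Because the parameters are already `int64_t`, passing `int` arguments promotes them at the call boundary, so no multiplication is carried out in 32-bit arithmetic. A caller could then pass the returned value as the first argument of `sycl_kernel_submit` and raise an error when the result is empty.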
BBBela
left a comment
Looks good to me.
Thank you! 😉
Directly no. It was taken from #3426
Please add a test case in torch-xpu-ops referring to |