Skip to content

Commit 9a9a466

Browse files
committed
Harden DeformConv index-width guard and align mask test comment
1 parent 5f402ac commit 9a9a466

2 files changed

Lines changed: 28 additions & 12 deletions

File tree

onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ inline bool Needs64BitIndex(Values... values) {
5454
return ((static_cast<int64_t>(values) > kInt32Max) || ...);
5555
}
5656

57+
// Reports whether the product of `factors` must be treated as too large for a 32-bit signed index,
// without ever performing a multiplication that could itself overflow int64_t.
//
// Result, in priority order:
//   1. Any negative factor, at any position, yields true. DeformConv dimensions are expected to be
//      non-negative after validation; if that is violated unexpectedly, conservatively force the
//      64-bit kernel path.
//   2. Otherwise the factors are folded left to right. A zero factor reached before the running
//      product has exceeded INT32_MAX yields false (the product is zero).
//   3. If multiplying in the next factor would push the running product above INT32_MAX, the
//      result is true.
//   4. An empty list, or a product that stays within INT32_MAX, yields false.
inline bool ProductExceedsInt32Max(std::initializer_list<int64_t> factors) {
  constexpr int64_t kInt32Max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());

  // Look for negatives up front so the verdict cannot depend on whether a zero happens to
  // appear earlier in the list (e.g. {0, -1} and {-1, 0} must agree).
  for (int64_t v : factors) {
    if (v < 0) return true;
  }

  // Invariant: 1 <= acc <= kInt32Max, so `acc *= v` below can never overflow int64_t.
  int64_t acc = 1;
  for (int64_t v : factors) {
    if (v == 0) return false;
    // Division-based test: acc * v > kInt32Max  <=>  acc > kInt32Max / v for positive v.
    if (acc > kInt32Max / v) return true;
    acc *= v;
  }
  return false;
}
70+
5771
// __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly.
5872
template <typename T>
5973
__device__ __inline__ T DeformConvLdg(const T* p) {
@@ -524,17 +538,19 @@ Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, i
524538
inline bool CheckDeformConvNeeds64BitIndex(
525539
int64_t num_kernels, int64_t C, int64_t H, int64_t W, int64_t kH, int64_t kW, int64_t out_h, int64_t out_w,
526540
int64_t parallel_imgs, int64_t offset_group) {
527-
const int64_t col_numel = static_cast<int64_t>(C) * kH * kW * parallel_imgs * out_h * out_w;
528-
const int64_t offset_inner_size = static_cast<int64_t>(2) * kH * kW * out_h * out_w;
529-
const int64_t mask_inner_size = kH * kW * out_h * out_w;
530-
const int64_t offset_numel = parallel_imgs * offset_group * offset_inner_size;
531-
const int64_t mask_numel = parallel_imgs * offset_group * mask_inner_size;
532-
const int64_t channel_hw = H * W;
533-
const int64_t batch_input_stride = C * channel_hw;
534-
const int64_t input_numel = parallel_imgs * batch_input_stride;
535-
536-
return Needs64BitIndex(num_kernels, col_numel, offset_inner_size, mask_inner_size, offset_numel, mask_numel,
537-
channel_hw, batch_input_stride, input_numel, offset_group);
541+
if (Needs64BitIndex(num_kernels, C, H, W, kH, kW, out_h, out_w, parallel_imgs, offset_group)) {
542+
return true;
543+
}
544+
545+
// Check potentially large products without evaluating intermediate multiplications.
546+
return ProductExceedsInt32Max({C, kH, kW, parallel_imgs, out_h, out_w}) || // col_numel
547+
ProductExceedsInt32Max({2, kH, kW, out_h, out_w}) || // offset_inner_size
548+
ProductExceedsInt32Max({kH, kW, out_h, out_w}) || // mask_inner_size
549+
ProductExceedsInt32Max({parallel_imgs, offset_group, 2, kH, kW, out_h, out_w}) || // offset_numel
550+
ProductExceedsInt32Max({parallel_imgs, offset_group, kH, kW, out_h, out_w}) || // mask_numel
551+
ProductExceedsInt32Max({H, W}) || // channel_hw
552+
ProductExceedsInt32Max({C, H, W}) || // batch_input_stride
553+
ProductExceedsInt32Max({parallel_imgs, C, H, W}); // input_numel
538554
}
539555

540556
template <typename T>

onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ TEST(DeformConvTest, Group1OffsetGroup2) {
860860
RunDeformConvTest<float>(p, X, W, offset, B, &mask, expected_Y);
861861
}
862862

863-
// Mask with zeros: exercises CUDA early-exit when mask_val == 0.
863+
// Mask with zeros: verifies zero mask suppresses sampled values (val * mask == 0).
864864
TEST(DeformConvTest, MaskWithZeros) {
865865
DeformConvTestParams p = {};
866866
p.batch_sz = 1;

0 commit comments

Comments
 (0)