@@ -54,6 +54,20 @@ inline bool Needs64BitIndex(Values... values) {
5454 return ((static_cast <int64_t >(values) > kInt32Max ) || ...);
5555}
5656
// Returns true when the product of `factors` is greater than INT32_MAX, or when
// any factor is negative; returns false otherwise (including for an empty list,
// whose product is treated as 1).
//
// The product is never formed past INT32_MAX: before each multiplication the
// running value is compared against kInt32Max / v, so the int64_t accumulator
// cannot overflow.
inline bool ProductExceedsInt32Max(std::initializer_list<int64_t> factors) {
  constexpr int64_t kInt32Max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());

  // DeformConv dimensions are expected to be non-negative after validation.
  // If violated unexpectedly, conservatively force the 64-bit kernel path.
  // Every factor is checked up front so the answer does not depend on whether
  // a zero happens to precede the negative value in the list.
  for (int64_t v : factors) {
    if (v < 0) return true;
  }

  int64_t acc = 1;
  for (int64_t v : factors) {
    if (v == 0) return false;              // Product is zero; also guards the division below.
    if (acc > kInt32Max / v) return true;  // acc * v would be greater than kInt32Max.
    acc *= v;
  }
  return false;
}
70+
5771// __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly.
5872template <typename T>
5973__device__ __inline__ T DeformConvLdg (const T* p) {
@@ -524,17 +538,19 @@ Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, i
524538inline bool CheckDeformConvNeeds64BitIndex (
525539 int64_t num_kernels, int64_t C, int64_t H, int64_t W, int64_t kH , int64_t kW , int64_t out_h, int64_t out_w,
526540 int64_t parallel_imgs, int64_t offset_group) {
527- const int64_t col_numel = static_cast <int64_t >(C) * kH * kW * parallel_imgs * out_h * out_w;
528- const int64_t offset_inner_size = static_cast <int64_t >(2 ) * kH * kW * out_h * out_w;
529- const int64_t mask_inner_size = kH * kW * out_h * out_w;
530- const int64_t offset_numel = parallel_imgs * offset_group * offset_inner_size;
531- const int64_t mask_numel = parallel_imgs * offset_group * mask_inner_size;
532- const int64_t channel_hw = H * W;
533- const int64_t batch_input_stride = C * channel_hw;
534- const int64_t input_numel = parallel_imgs * batch_input_stride;
535-
536- return Needs64BitIndex (num_kernels, col_numel, offset_inner_size, mask_inner_size, offset_numel, mask_numel,
537- channel_hw, batch_input_stride, input_numel, offset_group);
541+ if (Needs64BitIndex (num_kernels, C, H, W, kH , kW , out_h, out_w, parallel_imgs, offset_group)) {
542+ return true ;
543+ }
544+
545+ // Check potentially large products without evaluating intermediate multiplications.
546+ return ProductExceedsInt32Max ({C, kH , kW , parallel_imgs, out_h, out_w}) || // col_numel
547+ ProductExceedsInt32Max ({2 , kH , kW , out_h, out_w}) || // offset_inner_size
548+ ProductExceedsInt32Max ({kH , kW , out_h, out_w}) || // mask_inner_size
549+ ProductExceedsInt32Max ({parallel_imgs, offset_group, 2 , kH , kW , out_h, out_w}) || // offset_numel
550+ ProductExceedsInt32Max ({parallel_imgs, offset_group, kH , kW , out_h, out_w}) || // mask_numel
551+ ProductExceedsInt32Max ({H, W}) || // channel_hw
552+ ProductExceedsInt32Max ({C, H, W}) || // batch_input_stride
553+ ProductExceedsInt32Max ({parallel_imgs, C, H, W}); // input_numel
538554}
539555
540556template <typename T>
0 commit comments