@@ -148,7 +148,15 @@ __device__ __inline__ T BilinearInterpolate(
148148 using Traits = DeformConvBilinearTraits<T>;
149149 using CoordT = typename Traits::ComputeT;
150150
151- // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case).
151+ // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() and neighbor loads for OOB case).
152+ // Semantics guardrail: if sample point is outside (-1, H) x (-1, W), ONNX bilinear contribution is exactly 0.
153+ // Why keep this even with branchless masked loads below:
154+ // - The branchless path guarantees safe addressing and correct masked zero, but still pays floor/weight math
155+ // and four global loads.
156+ // - This early return avoids all of that work for clearly OOB samples.
157+ // About divergence: warps that mix in-bounds and OOB lanes can diverge here, but the OOB lanes return
158+ // immediately while the in-bounds lanes continue doing useful work. In practice the early exit usually wins;
159+ // the exception is when OOB samples are scattered randomly across lanes and the branch is taken very often.
152160 if (h <= static_cast <CoordT>(-1 ) || h >= height || w <= static_cast <CoordT>(-1 ) || w >= width) {
153161 return Traits::Zero ();
154162 }
@@ -347,29 +355,26 @@ __global__ void DeformableIm2ColKernel(
347355 }
348356 T* data_col_ptr = data_col_ptr_base + row_kernel_base * col_stride;
349357
358+ auto step_kernel_point = [&]() {
359+ process_kernel_point (offset_h_ptr, offset_w_ptr, mask_ptr, data_col_ptr, h_base, w_base);
360+ offset_h_ptr += offset_pair_stride;
361+ offset_w_ptr += offset_pair_stride;
362+ if constexpr (UseMask) {
363+ mask_ptr += out_size;
364+ }
365+ data_col_ptr += col_stride;
366+ w_base += static_cast <CoordT>(dilation_w);
367+ };
368+
350369 // Small fixed kernels: unroll inner j so codegen matches the old fully-unrolled 1x1/3x3 path.
351370 if constexpr (is_fixed && kH * kW <= 9 ) {
352371#pragma unroll
353372 for (IndexT j = 0 ; j < row_width; ++j) {
354- process_kernel_point (offset_h_ptr, offset_w_ptr, mask_ptr, data_col_ptr, h_base, w_base);
355- offset_h_ptr += offset_pair_stride;
356- offset_w_ptr += offset_pair_stride;
357- if constexpr (UseMask) {
358- mask_ptr += out_size;
359- }
360- data_col_ptr += col_stride;
361- w_base += static_cast <CoordT>(dilation_w);
373+ step_kernel_point ();
362374 }
363375 } else {
364376 for (IndexT j = 0 ; j < row_width; ++j) {
365- process_kernel_point (offset_h_ptr, offset_w_ptr, mask_ptr, data_col_ptr, h_base, w_base);
366- offset_h_ptr += offset_pair_stride;
367- offset_w_ptr += offset_pair_stride;
368- if constexpr (UseMask) {
369- mask_ptr += out_size;
370- }
371- data_col_ptr += col_stride;
372- w_base += static_cast <CoordT>(dilation_w);
377+ step_kernel_point ();
373378 }
374379 }
375380 };
@@ -586,22 +591,25 @@ Status DeformConvIm2ColImpl(
586591 }
587592 };
588593
589- auto launch_with_mask = [&](auto kH_tag , auto kW_tag ) {
594+ auto launch_with_mask = [&](auto k_size_tag ) {
590595 if (use_mask) {
591- launch (kH_tag , kW_tag , std::integral_constant<bool , true >{});
596+ launch (k_size_tag, k_size_tag , std::integral_constant<bool , true >{});
592597 } else {
593- launch (kH_tag , kW_tag , std::integral_constant<bool , false >{});
598+ launch (k_size_tag, k_size_tag , std::integral_constant<bool , false >{});
594599 }
595600 };
596601
602+ // Keep template specializations for the most common kernel sizes in modern models.
603+ // 5x5 is intentionally not specialized: it is less common in current architectures and is often
604+ // replaced by stacked 3x3 blocks (similar receptive field with better optimization flexibility).
597605 if (kH == 1 && kW == 1 ) {
598- launch_with_mask (DeformConvKSize<1 >{}, DeformConvKSize< 1 >{} );
606+ launch_with_mask (DeformConvKSize<1 >{});
599607 } else if (kH == 3 && kW == 3 ) {
600- launch_with_mask (DeformConvKSize<3 >{}, DeformConvKSize< 3 >{} );
608+ launch_with_mask (DeformConvKSize<3 >{});
601609 } else if (kH == 7 && kW == 7 ) {
602- launch_with_mask (DeformConvKSize<7 >{}, DeformConvKSize< 7 >{} );
610+ launch_with_mask (DeformConvKSize<7 >{});
603611 } else {
604- launch_with_mask (DeformConvKSize<-1 >{}, DeformConvKSize<- 1 >{} );
612+ launch_with_mask (DeformConvKSize<-1 >{});
605613 }
606614 return CUDA_CALL (cudaGetLastError ());
607615}
0 commit comments