Skip to content

Commit 3623db2

Browse files
committed
Improve comments and code styles
1 parent 9b265c3 commit 3623db2

1 file changed

Lines changed: 32 additions & 24 deletions

File tree

onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,15 @@ __device__ __inline__ T BilinearInterpolate(
148148
using Traits = DeformConvBilinearTraits<T>;
149149
using CoordT = typename Traits::ComputeT;
150150

151-
// [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case).
151+
// [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() and neighbor loads for OOB case).
152+
// Semantics guardrail: if sample point is outside [-1, H) x [-1, W), ONNX bilinear contribution is exactly 0.
153+
// Why keep this even with branchless masked loads below:
154+
// - The branchless path guarantees safe addressing and correct masked zero, but still pays floor/weight math
155+
// and four global loads.
156+
// - This early return avoids all of that work for clearly OOB samples.
157+
// About divergence: mixed in/out-of-bound warps can diverge here, but OOB lanes terminate immediately while
158+
// in-bound lanes continue useful work; in practice this often wins unless OOB distribution is highly random
159+
// and branch hit-rate is very high.
152160
if (h <= static_cast<CoordT>(-1) || h >= height || w <= static_cast<CoordT>(-1) || w >= width) {
153161
return Traits::Zero();
154162
}
@@ -347,29 +355,26 @@ __global__ void DeformableIm2ColKernel(
347355
}
348356
T* data_col_ptr = data_col_ptr_base + row_kernel_base * col_stride;
349357

358+
auto step_kernel_point = [&]() {
359+
process_kernel_point(offset_h_ptr, offset_w_ptr, mask_ptr, data_col_ptr, h_base, w_base);
360+
offset_h_ptr += offset_pair_stride;
361+
offset_w_ptr += offset_pair_stride;
362+
if constexpr (UseMask) {
363+
mask_ptr += out_size;
364+
}
365+
data_col_ptr += col_stride;
366+
w_base += static_cast<CoordT>(dilation_w);
367+
};
368+
350369
// Small fixed kernels: unroll inner j so codegen matches the old fully-unrolled 1x1/3x3 path.
351370
if constexpr (is_fixed && kH * kW <= 9) {
352371
#pragma unroll
353372
for (IndexT j = 0; j < row_width; ++j) {
354-
process_kernel_point(offset_h_ptr, offset_w_ptr, mask_ptr, data_col_ptr, h_base, w_base);
355-
offset_h_ptr += offset_pair_stride;
356-
offset_w_ptr += offset_pair_stride;
357-
if constexpr (UseMask) {
358-
mask_ptr += out_size;
359-
}
360-
data_col_ptr += col_stride;
361-
w_base += static_cast<CoordT>(dilation_w);
373+
step_kernel_point();
362374
}
363375
} else {
364376
for (IndexT j = 0; j < row_width; ++j) {
365-
process_kernel_point(offset_h_ptr, offset_w_ptr, mask_ptr, data_col_ptr, h_base, w_base);
366-
offset_h_ptr += offset_pair_stride;
367-
offset_w_ptr += offset_pair_stride;
368-
if constexpr (UseMask) {
369-
mask_ptr += out_size;
370-
}
371-
data_col_ptr += col_stride;
372-
w_base += static_cast<CoordT>(dilation_w);
377+
step_kernel_point();
373378
}
374379
}
375380
};
@@ -586,22 +591,25 @@ Status DeformConvIm2ColImpl(
586591
}
587592
};
588593

589-
auto launch_with_mask = [&](auto kH_tag, auto kW_tag) {
594+
auto launch_with_mask = [&](auto k_size_tag) {
590595
if (use_mask) {
591-
launch(kH_tag, kW_tag, std::integral_constant<bool, true>{});
596+
launch(k_size_tag, k_size_tag, std::integral_constant<bool, true>{});
592597
} else {
593-
launch(kH_tag, kW_tag, std::integral_constant<bool, false>{});
598+
launch(k_size_tag, k_size_tag, std::integral_constant<bool, false>{});
594599
}
595600
};
596601

602+
// Keep template specializations for the most common kernel sizes in modern models.
603+
// 5x5 is intentionally not specialized: it is less common in current architectures and is often
604+
// replaced by stacked 3x3 blocks (similar receptive field with better optimization flexibility).
597605
if (kH == 1 && kW == 1) {
598-
launch_with_mask(DeformConvKSize<1>{}, DeformConvKSize<1>{});
606+
launch_with_mask(DeformConvKSize<1>{});
599607
} else if (kH == 3 && kW == 3) {
600-
launch_with_mask(DeformConvKSize<3>{}, DeformConvKSize<3>{});
608+
launch_with_mask(DeformConvKSize<3>{});
601609
} else if (kH == 7 && kW == 7) {
602-
launch_with_mask(DeformConvKSize<7>{}, DeformConvKSize<7>{});
610+
launch_with_mask(DeformConvKSize<7>{});
603611
} else {
604-
launch_with_mask(DeformConvKSize<-1>{}, DeformConvKSize<-1>{});
612+
launch_with_mask(DeformConvKSize<-1>{});
605613
}
606614
return CUDA_CALL(cudaGetLastError());
607615
}

0 commit comments

Comments
 (0)