Skip to content

Commit d72b45f

Browse files
committed
Make deform conv bilinear sampling branchless with masked safe loads
1 parent 9b060ac commit d72b45f

1 file changed

Lines changed: 32 additions & 32 deletions

File tree

onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,15 @@ struct DeformConvBilinearTraits<BFloat16> {
129129
// cast to int only for indices (h_low/w_low), which avoids unnecessary CoordT->int->CoordT
130130
// round trips when computing lh/lw/hh/hw.
131131
//
132-
// Empirical share of samples that take the guarded "edge" branch (else below) vs fully interior
133-
// fast-path, from one workload (counts = edge samples / total bilinear samples).
132+
// Historical note: before switching to branchless masked loads, this workload had the following
133+
// "edge sample" ratio (counts = samples with >=1 OOB neighbor / total bilinear samples).
134+
// The numbers remain useful as boundary-hit context, but no longer imply control-flow divergence.
134135
// Example workload only; not a benchmark or representative ratio.
135136
// kernel 1x1: 1.3746% (2421 / 176128)
136137
// kernel 3x3: 1.4833% (11756 / 792576)
137138
// kernel 7x7: 4.7593% (52537 / 1103872)
138-
// Offsets tend to be spatially smooth, so warps often agree on fast-path vs edge-path, which
139-
// limits divergence versus always doing four per-neighbor conditional loads.
139+
// Current implementation always issues safe-address loads and masks invalid neighbors to zero.
140+
// Offsets are often spatially smooth, so nearby threads still tend to exhibit similar validity patterns.
140141
template <typename T>
141142
__device__ __inline__ T BilinearInterpolate(
142143
const T* in,
@@ -165,34 +166,33 @@ __device__ __inline__ T BilinearInterpolate(
165166
CoordT hh = static_cast<CoordT>(1) - lh;
166167
CoordT hw = static_cast<CoordT>(1) - lw;
167168

168-
// [Optimization 3]: Avoid a second multiply for base_high.
169-
// Original code computed both bases as:
170-
// base_low = h_low * width;
171-
// base_high = h_high * width;
172-
// Since h_high = h_low + 1, we can rewrite base_high as base_low + width and
173-
// save one integer multiply in the hot path:
174-
// base_low = h_low * width;
175-
// base_high = base_low + width;
176-
int base_low = h_low * width;
177-
int base_high = base_low + width;
178-
179-
CoordT v1, v2, v3, v4;
180-
181-
// [Optimization 4]: Interior fast-path when all four bilinear neighbors lie inside the image.
182-
// In that case we skip four per-neighbor ternary checks (see else). See block comment above for
183-
// measured edge-branch rates (~1.4%–4.8% in sampled configs); the majority of threads hit this branch.
184-
if (h_low >= 0 && w_low >= 0 && h_high < height && w_high < width) {
185-
v1 = Traits::Load(in + base_low + w_low);
186-
v2 = Traits::Load(in + base_low + w_high);
187-
v3 = Traits::Load(in + base_high + w_low);
188-
v4 = Traits::Load(in + base_high + w_high);
189-
} else {
190-
// Edge / partial OOB: same semantics as before (invalid neighbor contributes 0).
191-
v1 = (h_low >= 0 && w_low >= 0) ? Traits::Load(in + base_low + w_low) : static_cast<CoordT>(0);
192-
v2 = (h_low >= 0 && w_high < width) ? Traits::Load(in + base_low + w_high) : static_cast<CoordT>(0);
193-
v3 = (h_high < height && w_low >= 0) ? Traits::Load(in + base_high + w_low) : static_cast<CoordT>(0);
194-
v4 = (h_high < height && w_high < width) ? Traits::Load(in + base_high + w_high) : static_cast<CoordT>(0);
195-
}
169+
// [Optimization 4]: Branchless neighbor loads via "safe address + validity mask".
170+
// 1) Clamp each coordinate to a legal address first (prevents illegal memory access).
171+
// 2) Compute validity predicates for the true (possibly OOB) coordinates.
172+
// 3) Always load from clamped address and mask invalid neighbors to zero.
173+
// Modern CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
174+
const int safe_h_low = max(0, min(h_low, height - 1));
175+
const int safe_h_high = max(0, min(h_high, height - 1));
176+
const int safe_w_low = max(0, min(w_low, width - 1));
177+
const int safe_w_high = max(0, min(w_high, width - 1));
178+
179+
const int safe_base_low = safe_h_low * width;
180+
const int safe_base_high = safe_h_high * width;
181+
182+
const bool h_low_valid = (h_low >= 0 && h_low < height);
183+
const bool h_high_valid = (h_high >= 0 && h_high < height);
184+
const bool w_low_valid = (w_low >= 0 && w_low < width);
185+
const bool w_high_valid = (w_high >= 0 && w_high < width);
186+
187+
const CoordT m1 = static_cast<CoordT>(h_low_valid && w_low_valid);
188+
const CoordT m2 = static_cast<CoordT>(h_low_valid && w_high_valid);
189+
const CoordT m3 = static_cast<CoordT>(h_high_valid && w_low_valid);
190+
const CoordT m4 = static_cast<CoordT>(h_high_valid && w_high_valid);
191+
192+
const CoordT v1 = Traits::Load(in + safe_base_low + safe_w_low) * m1;
193+
const CoordT v2 = Traits::Load(in + safe_base_low + safe_w_high) * m2;
194+
const CoordT v3 = Traits::Load(in + safe_base_high + safe_w_low) * m3;
195+
const CoordT v4 = Traits::Load(in + safe_base_high + safe_w_high) * m4;
196196

197197
// [Optimization 5]: Factor bilinear into horizontal blends on two rows, then vertical blend.
198198
// Algebraically equivalent to w1*v1 + w2*v2 + w3*v3 + w4*v4 with w1..w4 from hh/hw/lh/lw;

0 commit comments

Comments
 (0)