@@ -129,14 +129,15 @@ struct DeformConvBilinearTraits<BFloat16> {
129129// cast to int only for indices (h_low/w_low), which avoids unnecessary CoordT->int->CoordT
130130// round trips when computing lh/lw/hh/hw.
131131//
132- // Empirical share of samples that take the guarded "edge" branch (else below) vs fully interior
133- // fast-path, from one workload (counts = edge samples / total bilinear samples).
132+ // Historical note: before switching to branchless masked loads, this workload had the following
133+ // "edge sample" ratio (counts = samples with >=1 OOB neighbor / total bilinear samples).
134+ // The numbers remain useful as boundary-hit context, but no longer imply control-flow divergence.
134135// Example workload only; not a benchmark or representative ratio.
135136// kernel 1x1: 1.3746% (2421 / 176128)
136137// kernel 3x3: 1.4833% (11756 / 792576)
137138// kernel 7x7: 4.7593% (52537 / 1103872)
138- // Offsets tend to be spatially smooth, so warps often agree on fast-path vs edge-path, which
139- // limits divergence versus always doing four per-neighbor conditional loads.
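139+ // Current implementation always issues safe-address loads and masks invalid neighbors to zero by multiplying each loaded value with a 0/1 validity mask.
140+ // Offsets are often spatially smooth, so nearby threads still tend to exhibit similar validity patterns. NOTE(review): a multiply (unlike a select) gives NaN when the clamped load is Inf/NaN - confirm inputs are finite.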
140141template <typename T>
141142__device__ __inline__ T BilinearInterpolate (
142143 const T* in,
@@ -165,34 +166,33 @@ __device__ __inline__ T BilinearInterpolate(
165166 CoordT hh = static_cast <CoordT>(1 ) - lh;
166167 CoordT hw = static_cast <CoordT>(1 ) - lw;
167168
168- // [Optimization 3]: Avoid a second multiply for base_high.
169- // Original code computed both bases as:
170- // base_low = h_low * width;
171- // base_high = h_high * width;
172- // Since h_high = h_low + 1, we can rewrite base_high as base_low + width and
173- // save one integer multiply in the hot path:
174- // base_low = h_low * width;
175- // base_high = base_low + width;
176- int base_low = h_low * width;
177- int base_high = base_low + width;
178-
179- CoordT v1, v2, v3, v4;
180-
181- // [Optimization 4]: Interior fast-path when all four bilinear neighbors lie inside the image.
182- // In that case we skip four per-neighbor ternary checks (see else). See block comment above for
183- // measured edge-branch rates (~1.4%–4.8% in sampled configs); the majority of threads hit this branch.
184- if (h_low >= 0 && w_low >= 0 && h_high < height && w_high < width) {
185- v1 = Traits::Load (in + base_low + w_low);
186- v2 = Traits::Load (in + base_low + w_high);
187- v3 = Traits::Load (in + base_high + w_low);
188- v4 = Traits::Load (in + base_high + w_high);
189- } else {
190- // Edge / partial OOB: same semantics as before (invalid neighbor contributes 0).
191- v1 = (h_low >= 0 && w_low >= 0 ) ? Traits::Load (in + base_low + w_low) : static_cast <CoordT>(0 );
192- v2 = (h_low >= 0 && w_high < width) ? Traits::Load (in + base_low + w_high) : static_cast <CoordT>(0 );
193- v3 = (h_high < height && w_low >= 0 ) ? Traits::Load (in + base_high + w_low) : static_cast <CoordT>(0 );
194- v4 = (h_high < height && w_high < width) ? Traits::Load (in + base_high + w_high) : static_cast <CoordT>(0 );
195- }
169+ // [Optimization 4]: Branchless neighbor loads via "safe address + validity mask".
170+ // 1) Clamp each coordinate to a legal address first (prevents illegal memory access).
171+ // 2) Compute validity predicates for the true (possibly OOB) coordinates.
172+ // 3) Always load from clamped address and mask invalid neighbors to zero.
173+ // Modern CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
174+ const int safe_h_low = max (0 , min (h_low, height - 1 ));
175+ const int safe_h_high = max (0 , min (h_high, height - 1 ));
176+ const int safe_w_low = max (0 , min (w_low, width - 1 ));
177+ const int safe_w_high = max (0 , min (w_high, width - 1 ));
178+
179+ const int safe_base_low = safe_h_low * width;
180+ const int safe_base_high = safe_h_high * width;
181+
182+ const bool h_low_valid = (h_low >= 0 && h_low < height);
183+ const bool h_high_valid = (h_high >= 0 && h_high < height);
184+ const bool w_low_valid = (w_low >= 0 && w_low < width);
185+ const bool w_high_valid = (w_high >= 0 && w_high < width);
186+
187+ const CoordT m1 = static_cast <CoordT>(h_low_valid && w_low_valid);
188+ const CoordT m2 = static_cast <CoordT>(h_low_valid && w_high_valid);
189+ const CoordT m3 = static_cast <CoordT>(h_high_valid && w_low_valid);
190+ const CoordT m4 = static_cast <CoordT>(h_high_valid && w_high_valid);
191+
192+ const CoordT v1 = Traits::Load (in + safe_base_low + safe_w_low) * m1;
193+ const CoordT v2 = Traits::Load (in + safe_base_low + safe_w_high) * m2;
194+ const CoordT v3 = Traits::Load (in + safe_base_high + safe_w_low) * m3;
195+ const CoordT v4 = Traits::Load (in + safe_base_high + safe_w_high) * m4;
196196
197197 // [Optimization 5]: Factor bilinear into horizontal blends on two rows, then vertical blend.
198198 // Algebraically equivalent to w1*v1 + w2*v2 + w3*v3 + w4*v4 with w1..w4 from hh/hw/lh/lw;
0 commit comments