@@ -188,33 +188,34 @@ __device__ __inline__ T BilinearInterpolate(
188188 CoordT hh = static_cast <CoordT>(1 ) - lh;
189189 CoordT hw = static_cast <CoordT>(1 ) - lw;
190190
191- // [Optimization 4]: Branchless neighbor loads via "safe address + validity mask".
192- // 1) Clamp each coordinate to a legal address first (prevents illegal memory access).
193- // 2) Compute validity predicates for the true (possibly OOB) coordinates.
194- // 3) Always load from clamped address and mask invalid neighbors to zero.
195- // Modern CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
196- const int safe_h_low = max (0 , min (h_low, height - 1 ));
197- const int safe_h_high = max (0 , min (h_high, height - 1 ));
198- const int safe_w_low = max (0 , min (w_low, width - 1 ));
199- const int safe_w_high = max (0 , min (w_high, width - 1 ));
191+ // [Optimization 3]: Branchless neighbor loads via "safe address + one-sided clamp".
192+ // Given the early return above, coordinates are in (-1, H) x (-1, W), so each index only needs one-sided clamp:
193+ // h_low in [-1, H-1], h_high in [0, H], w_low in [-1, W-1], w_high in [0, W].
194+ // We fuse validity into bilinear 1D weights (hh/lh/hw/lw), then always load from legal addresses.
195+ // CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
196+ const int safe_h_low = max (0 , h_low);
197+ const int safe_h_high = min (h_high, height - 1 );
198+ const int safe_w_low = max (0 , w_low);
199+ const int safe_w_high = min (w_high, width - 1 );
200+
201+ // [Optimization 4]: Mask validity into bilinear 1D weights.
202+ // Reuse the same invariant as above: each weight only needs the single bound that can still fail.
203+ // Masking 1D weights is algebraically equivalent to masking each 2D neighbor contribution.
204+ // Apply conditions directly on weights so the compiler can emit straightforward predicated selects.
205+ // Keep the zero in ComputeT (CoordT) to avoid T->ComputeT implicit conversions for half/BFloat16.
206+ const CoordT zero = static_cast <CoordT>(0 );
207+ hh = (h_low >= 0 ) ? hh : zero;
208+ lh = (h_high < height) ? lh : zero;
209+ hw = (w_low >= 0 ) ? hw : zero;
210+ lw = (w_high < width) ? lw : zero;
200211
201212 const int safe_base_low = safe_h_low * width;
202213 const int safe_base_high = safe_h_high * width;
203214
204- const bool h_low_valid = (h_low >= 0 && h_low < height);
205- const bool h_high_valid = (h_high >= 0 && h_high < height);
206- const bool w_low_valid = (w_low >= 0 && w_low < width);
207- const bool w_high_valid = (w_high >= 0 && w_high < width);
208-
209- const CoordT m1 = static_cast <CoordT>(h_low_valid && w_low_valid);
210- const CoordT m2 = static_cast <CoordT>(h_low_valid && w_high_valid);
211- const CoordT m3 = static_cast <CoordT>(h_high_valid && w_low_valid);
212- const CoordT m4 = static_cast <CoordT>(h_high_valid && w_high_valid);
213-
214- const CoordT v1 = Traits::Load (in + safe_base_low + safe_w_low) * m1;
215- const CoordT v2 = Traits::Load (in + safe_base_low + safe_w_high) * m2;
216- const CoordT v3 = Traits::Load (in + safe_base_high + safe_w_low) * m3;
217- const CoordT v4 = Traits::Load (in + safe_base_high + safe_w_high) * m4;
215+ const CoordT v1 = Traits::Load (in + safe_base_low + safe_w_low);
216+ const CoordT v2 = Traits::Load (in + safe_base_low + safe_w_high);
217+ const CoordT v3 = Traits::Load (in + safe_base_high + safe_w_low);
218+ const CoordT v4 = Traits::Load (in + safe_base_high + safe_w_high);
218219
219220 // [Optimization 5]: Factor bilinear into horizontal blends on two rows, then vertical blend.
220221 // Algebraically equivalent to w1*v1 + w2*v2 + w3*v3 + w4*v4 with w1..w4 from hh/hw/lh/lw;
0 commit comments