@@ -188,29 +188,29 @@ __device__ __inline__ T BilinearInterpolate(
188188 CoordT hh = static_cast <CoordT>(1 ) - lh;
189189 CoordT hw = static_cast <CoordT>(1 ) - lw;
190190
191- // [Optimization 4]: Branchless neighbor loads via "safe address + validity mask".
192- // 1) Clamp each coordinate to a legal address first (prevents illegal memory access).
193- // 2) Compute validity predicates for the true (possibly OOB) coordinates.
194- // 3) Always load from clamped address and mask invalid neighbors to zero.
195- // Modern CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
196- const int safe_h_low = max (0 , min (h_low, height - 1 ));
197- const int safe_h_high = max (0 , min (h_high, height - 1 ));
198- const int safe_w_low = max (0 , min (w_low, width - 1 ));
199- const int safe_w_high = max (0 , min (w_high, width - 1 ));
191+ // [Optimization 3]: Branchless neighbor loads via "safe address + one-sided clamp".
192+ // Given the early return above, coordinates are in (-1, H) x (-1, W), so each index only needs one-sided clamp:
193+ // h_low in [-1, H-1], h_high in [0, H], w_low in [-1, W-1], w_high in [0, W].
194+ // We always load from legal addresses; validity is applied by 2D neighbor masks below.
195+ // CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
196+ const int safe_h_low = max (0 , h_low);
197+ const int safe_h_high = min (h_high, height - 1 );
198+ const int safe_w_low = max (0 , w_low);
199+ const int safe_w_high = min (w_high, width - 1 );
200+
201+ // [Optimization 4]: One-sided validity checks under the same invariant.
202+ // Keep 2D neighbor masks (m1..m4), algebraically equivalent to masking invalid neighbor terms to zero.
203+ // Use one/zero ternaries directly in CoordT to encourage selp.f32/f16 generation.
204+ const CoordT one = static_cast <CoordT>(1 );
205+ const CoordT zero = static_cast <CoordT>(0 );
206+ const CoordT m1 = (h_low >= 0 && w_low >= 0 ) ? one : zero;
207+ const CoordT m2 = (h_low >= 0 && w_high < width) ? one : zero;
208+ const CoordT m3 = (h_high < height && w_low >= 0 ) ? one : zero;
209+ const CoordT m4 = (h_high < height && w_high < width) ? one : zero;
200210
201211 const int safe_base_low = safe_h_low * width;
202212 const int safe_base_high = safe_h_high * width;
203213
204- const bool h_low_valid = (h_low >= 0 && h_low < height);
205- const bool h_high_valid = (h_high >= 0 && h_high < height);
206- const bool w_low_valid = (w_low >= 0 && w_low < width);
207- const bool w_high_valid = (w_high >= 0 && w_high < width);
208-
209- const CoordT m1 = static_cast <CoordT>(h_low_valid && w_low_valid);
210- const CoordT m2 = static_cast <CoordT>(h_low_valid && w_high_valid);
211- const CoordT m3 = static_cast <CoordT>(h_high_valid && w_low_valid);
212- const CoordT m4 = static_cast <CoordT>(h_high_valid && w_high_valid);
213-
214214 const CoordT v1 = Traits::Load (in + safe_base_low + safe_w_low) * m1;
215215 const CoordT v2 = Traits::Load (in + safe_base_low + safe_w_high) * m2;
216216 const CoordT v3 = Traits::Load (in + safe_base_high + safe_w_low) * m3;
0 commit comments