Skip to content

Commit 37099b6

Browse files
committed
Optimize BilinearInterpolate with one-sided bounds and float mask selp
1 parent 9a9a466 commit 37099b6

1 file changed

Lines changed: 19 additions & 19 deletions

File tree

onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -188,29 +188,29 @@ __device__ __inline__ T BilinearInterpolate(
188188
CoordT hh = static_cast<CoordT>(1) - lh;
189189
CoordT hw = static_cast<CoordT>(1) - lw;
190190

191-
// [Optimization 4]: Branchless neighbor loads via "safe address + validity mask".
192-
// 1) Clamp each coordinate to a legal address first (prevents illegal memory access).
193-
// 2) Compute validity predicates for the true (possibly OOB) coordinates.
194-
// 3) Always load from clamped address and mask invalid neighbors to zero.
195-
// Modern CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
196-
const int safe_h_low = max(0, min(h_low, height - 1));
197-
const int safe_h_high = max(0, min(h_high, height - 1));
198-
const int safe_w_low = max(0, min(w_low, width - 1));
199-
const int safe_w_high = max(0, min(w_high, width - 1));
191+
// [Optimization 3]: Branchless neighbor loads via "safe address + one-sided clamp".
192+
// Given the early return above, coordinates are in (-1, H) x (-1, W), so each index only needs one-sided clamp:
193+
// h_low in [-1, H-1], h_high in [0, H], w_low in [-1, W-1], w_high in [0, W].
194+
// We always load from legal addresses; validity is applied by 2D neighbor masks below.
195+
// CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
196+
const int safe_h_low = max(0, h_low);
197+
const int safe_h_high = min(h_high, height - 1);
198+
const int safe_w_low = max(0, w_low);
199+
const int safe_w_high = min(w_high, width - 1);
200+
201+
// [Optimization 4]: One-sided validity checks under the same invariant.
202+
// Keep 2D neighbor masks (m1..m4), algebraically equivalent to masking invalid neighbor terms to zero.
203+
// Use one/zero ternaries directly in CoordT to encourage selp.f32/f16 generation.
204+
const CoordT one = static_cast<CoordT>(1);
205+
const CoordT zero = static_cast<CoordT>(0);
206+
const CoordT m1 = (h_low >= 0 && w_low >= 0) ? one : zero;
207+
const CoordT m2 = (h_low >= 0 && w_high < width) ? one : zero;
208+
const CoordT m3 = (h_high < height && w_low >= 0) ? one : zero;
209+
const CoordT m4 = (h_high < height && w_high < width) ? one : zero;
200210

201211
const int safe_base_low = safe_h_low * width;
202212
const int safe_base_high = safe_h_high * width;
203213

204-
const bool h_low_valid = (h_low >= 0 && h_low < height);
205-
const bool h_high_valid = (h_high >= 0 && h_high < height);
206-
const bool w_low_valid = (w_low >= 0 && w_low < width);
207-
const bool w_high_valid = (w_high >= 0 && w_high < width);
208-
209-
const CoordT m1 = static_cast<CoordT>(h_low_valid && w_low_valid);
210-
const CoordT m2 = static_cast<CoordT>(h_low_valid && w_high_valid);
211-
const CoordT m3 = static_cast<CoordT>(h_high_valid && w_low_valid);
212-
const CoordT m4 = static_cast<CoordT>(h_high_valid && w_high_valid);
213-
214214
const CoordT v1 = Traits::Load(in + safe_base_low + safe_w_low) * m1;
215215
const CoordT v2 = Traits::Load(in + safe_base_low + safe_w_high) * m2;
216216
const CoordT v3 = Traits::Load(in + safe_base_high + safe_w_low) * m3;

0 commit comments

Comments
 (0)