Skip to content

Commit 7f4e638

Browse files
committed
Optimize deformable bilinear interpolate with invariant-based clamps and weight masking
1 parent 9a9a466 commit 7f4e638

1 file changed

Lines changed: 24 additions & 23 deletions

File tree

onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -188,33 +188,34 @@ __device__ __inline__ T BilinearInterpolate(
188188
CoordT hh = static_cast<CoordT>(1) - lh;
189189
CoordT hw = static_cast<CoordT>(1) - lw;
190190

191-
// [Optimization 4]: Branchless neighbor loads via "safe address + validity mask".
192-
// 1) Clamp each coordinate to a legal address first (prevents illegal memory access).
193-
// 2) Compute validity predicates for the true (possibly OOB) coordinates.
194-
// 3) Always load from clamped address and mask invalid neighbors to zero.
195-
// Modern CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
196-
const int safe_h_low = max(0, min(h_low, height - 1));
197-
const int safe_h_high = max(0, min(h_high, height - 1));
198-
const int safe_w_low = max(0, min(w_low, width - 1));
199-
const int safe_w_high = max(0, min(w_high, width - 1));
191+
// [Optimization 3]: Branchless neighbor loads via "safe address + one-sided clamp".
192+
// Given the early return above, coordinates are in (-1, H) x (-1, W), so each index only needs one-sided clamp:
193+
// h_low in [-1, H-1], h_high in [0, H], w_low in [-1, W-1], w_high in [0, W].
194+
// We fuse validity into bilinear 1D weights (hh/lh/hw/lw), then always load from legal addresses.
195+
// CUDA compilers usually lower this to predicated/selp-style code without control-flow branches.
196+
const int safe_h_low = max(0, h_low);
197+
const int safe_h_high = min(h_high, height - 1);
198+
const int safe_w_low = max(0, w_low);
199+
const int safe_w_high = min(w_high, width - 1);
200+
201+
// [Optimization 4]: Mask validity into bilinear 1D weights.
202+
// Reuse the same invariant as above: each weight only needs the single bound that can still fail.
203+
// Masking 1D weights is algebraically equivalent to masking each 2D neighbor contribution.
204+
// Apply conditions directly on weights so the compiler can emit straightforward predicated selects.
205+
// Keep the zero in ComputeT (CoordT) to avoid T->ComputeT implicit conversions for half/BFloat16.
206+
const CoordT zero = static_cast<CoordT>(0);
207+
hh = (h_low >= 0) ? hh : zero;
208+
lh = (h_high < height) ? lh : zero;
209+
hw = (w_low >= 0) ? hw : zero;
210+
lw = (w_high < width) ? lw : zero;
200211

201212
const int safe_base_low = safe_h_low * width;
202213
const int safe_base_high = safe_h_high * width;
203214

204-
const bool h_low_valid = (h_low >= 0 && h_low < height);
205-
const bool h_high_valid = (h_high >= 0 && h_high < height);
206-
const bool w_low_valid = (w_low >= 0 && w_low < width);
207-
const bool w_high_valid = (w_high >= 0 && w_high < width);
208-
209-
const CoordT m1 = static_cast<CoordT>(h_low_valid && w_low_valid);
210-
const CoordT m2 = static_cast<CoordT>(h_low_valid && w_high_valid);
211-
const CoordT m3 = static_cast<CoordT>(h_high_valid && w_low_valid);
212-
const CoordT m4 = static_cast<CoordT>(h_high_valid && w_high_valid);
213-
214-
const CoordT v1 = Traits::Load(in + safe_base_low + safe_w_low) * m1;
215-
const CoordT v2 = Traits::Load(in + safe_base_low + safe_w_high) * m2;
216-
const CoordT v3 = Traits::Load(in + safe_base_high + safe_w_low) * m3;
217-
const CoordT v4 = Traits::Load(in + safe_base_high + safe_w_high) * m4;
215+
const CoordT v1 = Traits::Load(in + safe_base_low + safe_w_low);
216+
const CoordT v2 = Traits::Load(in + safe_base_low + safe_w_high);
217+
const CoordT v3 = Traits::Load(in + safe_base_high + safe_w_low);
218+
const CoordT v4 = Traits::Load(in + safe_base_high + safe_w_high);
218219

219220
// [Optimization 5]: Factor bilinear into horizontal blends on two rows, then vertical blend.
220221
// Algebraically equivalent to w1*v1 + w2*v2 + w3*v3 + w4*v4 with w1..w4 from hh/hw/lh/lw;

0 commit comments

Comments
 (0)