Skip to content

Commit 995f6e3

Browse files
committed
kernels/custom: grid_sampler_2d fp16 — accumulate in fp32
Match the precision of the portable kernel (after pytorch#19117) and avoid fp16 catastrophic cancellation on weight computation. The NEON half variant previously did interpolation weight computation and FMA accumulation in fp16 via vmul_f16 / vfma_f16; this change loads fp16, promotes to float32x4 via vcvt_f32_f16, does the four-corner FMA chain in fp32, and casts back to fp16 on store. Speed impact: two vcvt per 4-channel group — single-cycle on modern ARM, unmeasurable at op level in a full-model benchmark (3.5 ms for a typical call shape, unchanged). Precision impact: max_abs vs an fp32-then-down-cast reference drops from ~0.1 to 0 on the shapes the polycam depth model uses.
1 parent b5a2967 commit 995f6e3

1 file changed

Lines changed: 15 additions & 13 deletions

File tree

kernels/custom/neon_grid_sampler_2d.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ inline void bilinear_sample_all_channels(
146146
}
147147

148148
// Process one output spatial location for all channels using NEON, fp16 variant.
149+
// Math happens in fp32 (matches the portable kernel's precision); loads/stores
150+
// are fp16. vcvt_f32_f16 / vcvt_f16_f32 are single-cycle on modern ARM.
149151
inline void bilinear_sample_all_channels_fp16(
150152
const __fp16* input_n,
151153
__fp16* output_n,
@@ -161,10 +163,10 @@ inline void bilinear_sample_all_channels_fp16(
161163
float fx = gx - static_cast<float>(x0);
162164
float fy = gy - static_cast<float>(y0);
163165

164-
float16x4_t vw_tl = vdup_n_f16(static_cast<__fp16>((1.0f - fx) * (1.0f - fy)));
165-
float16x4_t vw_tr = vdup_n_f16(static_cast<__fp16>(fx * (1.0f - fy)));
166-
float16x4_t vw_bl = vdup_n_f16(static_cast<__fp16>((1.0f - fx) * fy));
167-
float16x4_t vw_br = vdup_n_f16(static_cast<__fp16>(fx * fy));
166+
float32x4_t vw_tl = vdupq_n_f32((1.0f - fx) * (1.0f - fy));
167+
float32x4_t vw_tr = vdupq_n_f32(fx * (1.0f - fy));
168+
float32x4_t vw_bl = vdupq_n_f32((1.0f - fx) * fy);
169+
float32x4_t vw_br = vdupq_n_f32(fx * fy);
168170

169171
bool tl_valid = (static_cast<unsigned>(x0) < static_cast<unsigned>(W_in) &&
170172
static_cast<unsigned>(y0) < static_cast<unsigned>(H_in));
@@ -214,18 +216,18 @@ inline void bilinear_sample_all_channels_fp16(
214216
br[2] = p2[off_br]; br[3] = p3[off_br];
215217
}
216218

217-
float16x4_t v_tl = vld1_f16(tl);
218-
float16x4_t v_tr = vld1_f16(tr);
219-
float16x4_t v_bl = vld1_f16(bl);
220-
float16x4_t v_br = vld1_f16(br);
219+
float32x4_t v_tl = vcvt_f32_f16(vld1_f16(tl));
220+
float32x4_t v_tr = vcvt_f32_f16(vld1_f16(tr));
221+
float32x4_t v_bl = vcvt_f32_f16(vld1_f16(bl));
222+
float32x4_t v_br = vcvt_f32_f16(vld1_f16(br));
221223

222-
float16x4_t result = vmul_f16(vw_tl, v_tl);
223-
result = vfma_f16(result, vw_tr, v_tr);
224-
result = vfma_f16(result, vw_bl, v_bl);
225-
result = vfma_f16(result, vw_br, v_br);
224+
float32x4_t result = vmulq_f32(vw_tl, v_tl);
225+
result = vfmaq_f32(result, vw_tr, v_tr);
226+
result = vfmaq_f32(result, vw_bl, v_bl);
227+
result = vfmaq_f32(result, vw_br, v_br);
226228

227229
__fp16 res[4];
228-
vst1_f16(res, result);
230+
vst1_f16(res, vcvt_f16_f32(result));
229231
output_n[(c + 0) * spatial_out + out_off] = res[0];
230232
output_n[(c + 1) * spatial_out + out_off] = res[1];
231233
output_n[(c + 2) * spatial_out + out_off] = res[2];

0 commit comments

Comments
 (0)