Skip to content

Commit 7de925b

Browse files
asg017 and claude committed
Fix int16 overflow in l2_sqr_int8_neon SIMD distance
vmulq_s16(diff, diff) produced int16 results, but the magnitude of diff can be as large as 255 for int8 vectors (-128 vs 127), and 255^2 = 65025 overflows int16 (max 32767). This caused NaN/wrong results for int8 vectors with large differences.

Fix: use vmull_s16 (widening multiply) to produce int32 results directly, avoiding the intermediate int16 overflow.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4bee883 commit 7de925b

2 files changed

Lines changed: 14 additions & 5 deletions

File tree

sqlite-vec.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,16 @@ static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v,
258258
pVect1 += 8;
259259
pVect2 += 8;
260260

261-
// widen to protect against overflow
261+
// widen i8 to i16 for subtraction
262262
int16x8_t v1_wide = vmovl_s8(v1);
263263
int16x8_t v2_wide = vmovl_s8(v2);
264-
265264
int16x8_t diff = vsubq_s16(v1_wide, v2_wide);
266-
int16x8_t squared_diff = vmulq_s16(diff, diff);
267-
int32x4_t sum = vpaddlq_s16(squared_diff);
265+
266+
// widening multiply: i16*i16 -> i32 to avoid i16 overflow
267+
// (diff can be up to 255, so diff*diff can be up to 65025 > INT16_MAX)
268+
int32x4_t sq_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
269+
int32x4_t sq_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
270+
int32x4_t sum = vaddq_s32(sq_lo, sq_hi);
268271

269272
sum_scalar += vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) +
270273
vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3);

tests/test-loadable.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,17 @@ def check(a, b, dtype=np.float32):
381381

382382
x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform)
383383
y = npy_l2(np.array(a), np.array(b))
384-
assert isclose(x, y, abs_tol=1e-6)
384+
assert isclose(x, y, rel_tol=1e-5, abs_tol=1e-6)
385385

386386
check([1.2, 0.1], [0.4, -0.4])
387387
check([-1.2, -0.1], [-0.4, 0.4])
388388
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
389+
# Extreme int8 values: diff=255, squared=65025 which overflows i16
390+
# This tests the NEON widening multiply fix (slight float rounding expected)
391+
check([-128] * 8, [127] * 8, dtype=np.int8)
392+
check([-128] * 16, [127] * 16, dtype=np.int8)
393+
check([-128, 127, -128, 127, -128, 127, -128, 127],
394+
[127, -128, 127, -128, 127, -128, 127, -128], dtype=np.int8)
389395

390396

391397
def test_vec_length():

0 commit comments

Comments
 (0)