Skip to content

Commit 6d48ced

Browse files
committed
fix(bgz17): guard SIMD gather over-read on final distance-matrix row
The AVX2/AVX-512 batch palette lookup uses an i32 gather with scale=2 over the u16 distance matrix, loading each lane's target u16 PLUS the following u16 into the high half (masked off). On the last valid index of dm_data (query==k-1, candidate k-1) that high half reads one u16 past the backing slice — an out-of-bounds read. Guard both the AVX2 and AVX-512 paths: when the row's worst-case over-read (row_offset + k, candidate k-1's high half) is not strictly inside dm_data, route the whole batch through the scalar path, which only ever touches row_offset + c. Affects at most the final row of a tight k×k matrix (<=1/k of queries). Adds test_batch_last_row_full_width regression: k=64 (exact multiple of both SIMD widths), last row, full-width batch including candidate k-1. https://claude.ai/code/session_01D2WSmezQBNC3bUdHuGfGmo
1 parent 31d7757 commit 6d48ced

1 file changed

Lines changed: 59 additions & 0 deletions

File tree

crates/bgz17/src/simd.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@ unsafe fn avx2_batch(dm_data: &[u16], k: usize, query: u8, candidates: &[u8], ou
9393
use core::arch::x86_64::*;
9494

9595
let row_offset = query as usize * k;
96+
97+
// The i32 gather (scale=2 on the u16 base) loads the target u16 AND the
98+
// following u16 into each lane's high half (masked off below). For the last
99+
// valid index of `dm_data` that high half is out of bounds, so when the
100+
// row's worst-case over-read (`row_offset + k`, candidate k-1's high half)
101+
// is not strictly inside the backing slice, route the whole batch through
102+
// the scalar path — which only ever touches `row_offset + c`. This affects
103+
// at most the final row of a tight k×k matrix (≤1/k of queries).
104+
if row_offset + k >= dm_data.len() {
105+
scalar_batch(dm_data, k, query, candidates, out);
106+
return;
107+
}
108+
96109
let row_ptr = dm_data.as_ptr().add(row_offset);
97110
let n = candidates.len();
98111

@@ -160,6 +173,19 @@ unsafe fn avx512_batch(dm_data: &[u16], k: usize, query: u8, candidates: &[u8],
160173
use core::arch::x86_64::*;
161174

162175
let row_offset = query as usize * k;
176+
177+
// The i32 gather (scale=2 on the u16 base) loads the target u16 AND the
178+
// following u16 into each lane's high half (masked off below). For the last
179+
// valid index of `dm_data` that high half is out of bounds, so when the
180+
// row's worst-case over-read (`row_offset + k`, candidate k-1's high half)
181+
// is not strictly inside the backing slice, route the whole batch through
182+
// the scalar path — which only ever touches `row_offset + c`. This affects
183+
// at most the final row of a tight k×k matrix (≤1/k of queries).
184+
if row_offset + k >= dm_data.len() {
185+
scalar_batch(dm_data, k, query, candidates, out);
186+
return;
187+
}
188+
163189
let row_ptr = dm_data.as_ptr().add(row_offset);
164190
let n = candidates.len();
165191

@@ -366,6 +392,39 @@ mod tests {
366392
assert_eq!(out[0], 0);
367393
}
368394

395+
#[test]
396+
fn test_batch_last_row_full_width() {
397+
// Regression: the SIMD gather (i32 gather, scale=2 over the u16 matrix)
398+
// over-reads the u16 following each lane's target. On the very last row
399+
// (query == k-1) with the last candidate (k-1), that over-read lands one
400+
// u16 past `dm.data` — a memory-safety bug. k=64 makes 64 candidates an
401+
// exact multiple of both the AVX-512 (16) and AVX2 (8) widths, so the
402+
// last candidate is processed by the gather path, not the scalar tail.
403+
// The boundary guard must route this row to scalar and still be correct.
404+
let pal = make_palette(64);
405+
let dm = DistanceMatrix::build(&pal);
406+
assert_eq!(
407+
dm.data.len(),
408+
dm.k * dm.k,
409+
"tight k×k matrix (no trailing slack)"
410+
);
411+
412+
let query = (dm.k - 1) as u8; // last row
413+
let candidates: Vec<u8> = (0..dm.k as u8).collect(); // includes k-1
414+
let mut batch_out = vec![0u16; dm.k];
415+
416+
batch_palette_distance(&dm.data, dm.k, query, &candidates, &mut batch_out);
417+
418+
for (i, &cand) in candidates.iter().enumerate() {
419+
let expected = dm.distance(query, cand);
420+
assert_eq!(
421+
batch_out[i], expected,
422+
"last-row mismatch at candidate {}: batch={} scalar={}",
423+
cand, batch_out[i], expected
424+
);
425+
}
426+
}
427+
369428
#[test]
370429
fn test_detect_simd() {
371430
let level = detect_simd();

0 commit comments

Comments
 (0)