Skip to content

Commit 0108b96

Browse files
perf: speed up RaBitQ 4-bit LUT distance on ARM by 16x (lance-format#6537)
1 parent d0124ed commit 0108b96

4 files changed

Lines changed: 292 additions & 12 deletions

File tree

rust/lance-index/src/vector/bq/storage.rs

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -348,17 +348,8 @@ impl DistCalculator for RabitDistCalculator<'_> {
348348
let id = id as usize;
349349
let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
350350
let num_vectors = self.codes.len() / code_len;
351-
let code = get_rq_code(self.codes, id, num_vectors, code_len);
352-
let dist = code
353-
.zip(self.dist_table.chunks_exact(SEGMENT_NUM_CODES).tuples())
354-
.map(|(code_byte, (dist_table, next_dist_table))| {
355-
// code is a bit vector, we iterate over 8 bits at a time,
356-
// every 4 bits is a sub-vector, we need to extract the bits
357-
let current_code = (code_byte & 0x0F) as usize;
358-
let next_code = (code_byte >> 4) as usize;
359-
dist_table[current_code] + next_dist_table[next_code]
360-
})
361-
.sum::<f32>();
351+
let dist =
352+
compute_single_rq_distance(self.codes, id, num_vectors, code_len, &self.dist_table);
362353

363354
// distance between quantized vector and query vector
364355
let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d;
@@ -799,6 +790,60 @@ impl QuantizerStorage for RabitQuantizationStorage {
799790
}
800791
}
801792

793+
/// Compute the raw distance for a single vector without allocating.
794+
///
795+
/// Fuses code extraction from the packed layout with distance accumulation
796+
/// in a single pass, avoiding the intermediate `Vec` allocation that
797+
/// `get_rq_code` + iterator would require.
798+
#[inline]
799+
fn compute_single_rq_distance(
800+
codes: &[u8],
801+
id: usize,
802+
num_vectors: usize,
803+
num_code_bytes: usize,
804+
dist_table: &[f32],
805+
) -> f32 {
806+
let remainder = num_vectors % BATCH_SIZE;
807+
let mut dist_table_iter = dist_table.chunks_exact(SEGMENT_NUM_CODES).tuples();
808+
809+
if id < num_vectors - remainder {
810+
let batch_codes = &codes[id / BATCH_SIZE * BATCH_SIZE * num_code_bytes
811+
..(id / BATCH_SIZE + 1) * BATCH_SIZE * num_code_bytes];
812+
813+
let id_in_batch = id % BATCH_SIZE;
814+
let idx = PERM0_INVERSE[id_in_batch % 16];
815+
let is_lower = id_in_batch < 16;
816+
817+
let mut dist = 0.0f32;
818+
for block in batch_codes.chunks_exact(BATCH_SIZE) {
819+
let code_byte = if is_lower {
820+
(block[idx] & 0xF) | (block[idx + 16] << 4)
821+
} else {
822+
(block[idx] >> 4) | (block[idx + 16] & 0xF0)
823+
};
824+
if let Some((current_dt, next_dt)) = dist_table_iter.next() {
825+
let current_code = (code_byte & 0x0F) as usize;
826+
let next_code = (code_byte >> 4) as usize;
827+
dist += current_dt[current_code] + next_dt[next_code];
828+
}
829+
}
830+
dist
831+
} else {
832+
let offset_id = id - (num_vectors - remainder);
833+
let remainder_codes = &codes[(num_vectors - remainder) * num_code_bytes..];
834+
835+
let mut dist = 0.0f32;
836+
for &code_byte in remainder_codes.iter().skip(offset_id).step_by(remainder) {
837+
if let Some((current_dt, next_dt)) = dist_table_iter.next() {
838+
let current_code = (code_byte & 0x0F) as usize;
839+
let next_code = (code_byte >> 4) as usize;
840+
dist += current_dt[current_code] + next_dt[next_code];
841+
}
842+
}
843+
dist
844+
}
845+
}
846+
802847
#[inline]
803848
fn get_rq_code(
804849
codes: &[u8],

rust/lance-linalg/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,9 @@ harness = false
6262
name = "norm_l2"
6363
harness = false
6464

65+
[[bench]]
66+
name = "dist_table"
67+
harness = false
68+
6569
[lints]
6670
workspace = true
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright The Lance Authors
3+
4+
//! Benchmark of 4-bit LUT distance table summation (RaBitQ inner loop).
5+
//!
6+
//! Measures both the dispatched path (NEON on ARM, AVX2 on x86) and the
7+
//! scalar fallback, so the speedup is visible in a single benchmark run.
8+
9+
use std::iter::repeat_with;
10+
11+
use criterion::{Criterion, black_box, criterion_group, criterion_main};
12+
use lance_linalg::simd::dist_table::{BATCH_SIZE, sum_4bit_dist_table, sum_4bit_dist_table_scalar};
13+
use rand::Rng;
14+
15+
fn bench_sum_4bit_dist_table(c: &mut Criterion) {
16+
let mut rng = rand::rng();
17+
18+
// code_len = dim / 8 for 1-bit quantization
19+
for (label, n_vectors, code_len) in [
20+
("32vec_dim128", 32_usize, 16_usize),
21+
("32vec_dim1536", 32, 192),
22+
("32vec_dim4096", 32, 512),
23+
("32vec_dim65536", 32, 8192),
24+
("16Kvec_dim128", 16_000, 16),
25+
("16Kvec_dim1536", 16_000, 192),
26+
] {
27+
let n = n_vectors.div_ceil(BATCH_SIZE) * BATCH_SIZE;
28+
29+
let codes: Vec<u8> = repeat_with(|| rng.random::<u8>())
30+
.take(n * code_len)
31+
.collect();
32+
33+
let dist_table: Vec<u8> = repeat_with(|| rng.random::<u8>())
34+
.take(BATCH_SIZE * code_len)
35+
.collect();
36+
37+
let mut dists = vec![0u16; n];
38+
39+
// Dispatched path (NEON on ARM, AVX2 on x86)
40+
c.bench_function(&format!("sum_4bit_dist_table/simd/{}", label), |b| {
41+
b.iter(|| {
42+
dists.fill(0);
43+
sum_4bit_dist_table(n, code_len, &codes, &dist_table, &mut dists);
44+
black_box(&dists);
45+
})
46+
});
47+
48+
// Scalar reference path
49+
c.bench_function(&format!("sum_4bit_dist_table/scalar/{}", label), |b| {
50+
b.iter(|| {
51+
dists.fill(0);
52+
sum_4bit_dist_table_scalar(code_len, &codes, &dist_table, &mut dists);
53+
black_box(&dists);
54+
})
55+
});
56+
}
57+
}
58+
59+
criterion_group!(
60+
name = benches;
61+
config = Criterion::default().significance_level(0.1).sample_size(10);
62+
targets = bench_sum_4bit_dist_table
63+
);
64+
criterion_main!(benches);

rust/lance-linalg/src/simd/dist_table.rs

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright The Lance Authors
33

4+
#[cfg(target_arch = "aarch64")]
5+
use std::arch::aarch64::*;
46
#[cfg(target_arch = "x86_64")]
57
use std::arch::x86_64::*;
68

@@ -57,13 +59,28 @@ pub fn sum_4bit_dist_table(
5759
)
5860
}
5961
},
62+
#[cfg(target_arch = "aarch64")]
63+
SimdSupport::Neon => unsafe {
64+
for i in (0..n).step_by(BATCH_SIZE) {
65+
sum_dist_table_32bytes_batch_neon(
66+
&codes[i * code_len..(i + BATCH_SIZE) * code_len],
67+
dist_table,
68+
&mut dists[i..i + BATCH_SIZE],
69+
)
70+
}
71+
},
6072
_ => sum_4bit_dist_table_scalar(code_len, codes, dist_table, dists),
6173
}
6274
}
6375

6476
#[inline]
6577
#[allow(unused)]
66-
fn sum_4bit_dist_table_scalar(code_len: usize, codes: &[u8], dist_table: &[u8], dists: &mut [u16]) {
78+
pub fn sum_4bit_dist_table_scalar(
79+
code_len: usize,
80+
codes: &[u8],
81+
dist_table: &[u8],
82+
dists: &mut [u16],
83+
) {
6784
for (vec_block_idx, blocks) in codes.chunks_exact(BATCH_SIZE * code_len).enumerate() {
6885
for (sub_vec_idx, block) in blocks.chunks_exact(BATCH_SIZE).enumerate() {
6986
let current_dist_table = &dist_table[sub_vec_idx * 2 * 16..(sub_vec_idx * 2 + 1) * 16];
@@ -159,6 +176,77 @@ unsafe fn sum_dist_table_32bytes_batch_avx2(codes: &[u8], dist_table: &[u8], dis
159176
_mm256_storeu_si256(dists.as_mut_ptr().add(16) as *mut __m256i, dis1);
160177
}
161178

179+
/// NEON kernel: sums 4-bit lookup-table distances for one batch of 32 vectors.
///
/// `codes` is consumed in 32-byte blocks. For block `i`, bytes `i..i + 16`
/// are looked up in the 16-entry table at `dist_table[i..i + 16]` and bytes
/// `i + 16..i + 32` in the table at `dist_table[i + 16..i + 32]`; each byte
/// yields one lookup for its low nibble and one for its high nibble. Lookup
/// results are accumulated with wrapping `u16` additions.
///
/// The first 32 entries of `dists` are overwritten (not accumulated into):
/// * `dists[0..8]`   — low-nibble sums of the even byte positions,
/// * `dists[8..16]`  — low-nibble sums of the odd byte positions,
/// * `dists[16..24]` — high-nibble sums of the even byte positions,
/// * `dists[24..32]` — high-nibble sums of the odd byte positions.
///
/// # Safety
///
/// The loads and stores go through raw pointers without bounds checks, so
/// the caller must guarantee that:
/// * `codes.len()` is a multiple of 32,
/// * `dist_table.len() >= codes.len()`,
/// * `dists.len() >= 32`.
///
/// These requirements are verified with `debug_assert!` in debug builds only.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn sum_dist_table_32bytes_batch_neon(codes: &[u8], dist_table: &[u8], dists: &mut [u16]) {
    // Bytes consumed from `codes`/`dist_table` per step, and `u16`s written to `dists`.
    const BLOCK_BYTES: usize = 32;

    debug_assert_eq!(
        codes.len() % BLOCK_BYTES,
        0,
        "codes length must be a multiple of 32 bytes"
    );
    debug_assert!(
        dist_table.len() >= codes.len(),
        "dist_table must provide 32 bytes for every 32-byte block of codes"
    );
    debug_assert!(
        dists.len() >= BLOCK_BYTES,
        "dists must have room for 32 results"
    );

    let low_mask = vdupq_n_u8(0x0f);

    // 8 accumulators: 4 per 128-bit "lane" (lo = bytes 0..16, hi = bytes 16..32 of each block)
    let mut accu0_lo = vdupq_n_u16(0);
    let mut accu1_lo = vdupq_n_u16(0);
    let mut accu2_lo = vdupq_n_u16(0);
    let mut accu3_lo = vdupq_n_u16(0);
    let mut accu0_hi = vdupq_n_u16(0);
    let mut accu1_hi = vdupq_n_u16(0);
    let mut accu2_hi = vdupq_n_u16(0);
    let mut accu3_hi = vdupq_n_u16(0);

    let codes_ptr = codes.as_ptr();
    let dt_ptr = dist_table.as_ptr();

    for i in (0..codes.len()).step_by(BLOCK_BYTES) {
        // Process lo lane: bytes [i..i+16]
        let c_lo = vld1q_u8(codes_ptr.add(i));
        let lut_lo = vld1q_u8(dt_ptr.add(i));

        // Split every code byte into its two 4-bit table indices.
        let lo_lo = vandq_u8(c_lo, low_mask);
        let hi_lo = vshrq_n_u8::<4>(c_lo);

        let res_lo_lo = vqtbl1q_u8(lut_lo, lo_lo);
        let res_hi_lo = vqtbl1q_u8(lut_lo, hi_lo);

        // Viewed as u16, each element holds (odd byte << 8) | even byte.
        // accu0/accu2 add that raw pair; accu1/accu3 add only the odd byte.
        accu0_lo = vaddq_u16(accu0_lo, vreinterpretq_u16_u8(res_lo_lo));
        accu1_lo = vaddq_u16(accu1_lo, vshrq_n_u16::<8>(vreinterpretq_u16_u8(res_lo_lo)));
        accu2_lo = vaddq_u16(accu2_lo, vreinterpretq_u16_u8(res_hi_lo));
        accu3_lo = vaddq_u16(accu3_lo, vshrq_n_u16::<8>(vreinterpretq_u16_u8(res_hi_lo)));

        // Process hi lane: bytes [i+16..i+32]
        let c_hi = vld1q_u8(codes_ptr.add(i + 16));
        let lut_hi = vld1q_u8(dt_ptr.add(i + 16));

        let lo_hi = vandq_u8(c_hi, low_mask);
        let hi_hi = vshrq_n_u8::<4>(c_hi);

        let res_lo_hi = vqtbl1q_u8(lut_hi, lo_hi);
        let res_hi_hi = vqtbl1q_u8(lut_hi, hi_hi);

        accu0_hi = vaddq_u16(accu0_hi, vreinterpretq_u16_u8(res_lo_hi));
        accu1_hi = vaddq_u16(accu1_hi, vshrq_n_u16::<8>(vreinterpretq_u16_u8(res_lo_hi)));
        accu2_hi = vaddq_u16(accu2_hi, vreinterpretq_u16_u8(res_hi_hi));
        accu3_hi = vaddq_u16(accu3_hi, vshrq_n_u16::<8>(vreinterpretq_u16_u8(res_hi_hi)));
    }

    // Merge: clean even bytes by subtracting the odd-byte bleed
    accu0_lo = vsubq_u16(accu0_lo, vshlq_n_u16::<8>(accu1_lo));
    accu0_hi = vsubq_u16(accu0_hi, vshlq_n_u16::<8>(accu1_hi));

    // Cross-lane merge: add lo and hi lane accumulators
    // This is the NEON equivalent of AVX2's permute2f128 + blend + add
    let dis0_even = vaddq_u16(accu0_lo, accu0_hi);
    let dis0_odd = vaddq_u16(accu1_lo, accu1_hi);
    vst1q_u16(dists.as_mut_ptr(), dis0_even);
    vst1q_u16(dists.as_mut_ptr().add(8), dis0_odd);

    // Same for hi-nibble accumulators (vectors 16..31)
    accu2_lo = vsubq_u16(accu2_lo, vshlq_n_u16::<8>(accu3_lo));
    accu2_hi = vsubq_u16(accu2_hi, vshlq_n_u16::<8>(accu3_hi));

    let dis1_even = vaddq_u16(accu2_lo, accu2_hi);
    let dis1_odd = vaddq_u16(accu3_lo, accu3_hi);
    vst1q_u16(dists.as_mut_ptr().add(16), dis1_even);
    vst1q_u16(dists.as_mut_ptr().add(24), dis1_odd);
}
249+
162250
// We implement the AVX512 version in C because AVX512 is not stable yet in Rust.
163251
// TODO: port it to Rust once we upgrade Rust to 1.89.0.
164252
unsafe extern "C" {
@@ -214,4 +302,83 @@ mod tests {
214302
// so the distance is 2 * (dist_table[0x6] + dist_table[0xb + 16]) = 2*(7 + 12) = 38
215303
assert_eq!(dists[1], 38);
216304
}
305+
306+
/// Checks that the dispatched kernel (NEON on ARM, AVX2 on x86) agrees with
/// the scalar reference on a single batch, for code lengths ranging from tiny
/// up to very large (code_len=8192, i.e. DIM=65536).
///
/// Note: dist_table values are capped so the per-vector sum cannot overflow
/// u16, matching production behavior where values are quantized to a small
/// range. (Per the original author's note, the scalar path uses
/// saturating_add while SIMD uses wrapping add, so the two would diverge on
/// overflow — but overflow never occurs with real quantized data.)
#[test]
fn test_simd_matches_scalar_varied_dimensions() {
    use rand::{Rng, SeedableRng};

    // code_len = dim / 8 for 1-bit quantization; code lengths are exercised
    // directly since that is what the function sees.
    // code_len=16 → DIM=128, code_len=192 → DIM=1536,
    // code_len=512 → DIM=4096, code_len=8192 → DIM=65536
    const CODE_LENS: [usize; 7] = [2, 16, 96, 192, 512, 1024, 8192];

    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
    // 32 vectors per batch
    let num_vectors = BATCH_SIZE;

    for code_len in CODE_LENS {
        // Each code byte produces 2 lookups; cap values so
        // 2 * code_len * max_val < u16::MAX.
        let lookups_per_vector = 2 * code_len;
        let max_val = (u16::MAX as usize / lookups_per_vector).min(255) as u8;

        let codes: Vec<u8> = std::iter::repeat_with(|| rng.random::<u8>())
            .take(num_vectors * code_len)
            .collect();
        let dist_table: Vec<u8> = std::iter::repeat_with(|| rng.random_range(0..=max_val))
            .take(BATCH_SIZE * code_len)
            .collect();

        let mut scalar_dists = vec![0u16; num_vectors];
        sum_4bit_dist_table_scalar(code_len, &codes, &dist_table, &mut scalar_dists);

        let mut simd_dists = vec![0u16; num_vectors];
        sum_4bit_dist_table(num_vectors, code_len, &codes, &dist_table, &mut simd_dists);

        assert_eq!(
            simd_dists,
            scalar_dists,
            "SIMD and scalar mismatch for code_len={} (DIM={})",
            code_len,
            code_len * 8,
        );
    }
}
351+
352+
/// Runs the SIMD-vs-scalar comparison over several batches at once, so that
/// results written for one batch are checked not to disturb another.
#[test]
fn test_simd_matches_scalar_multi_batch() {
    use rand::{Rng, SeedableRng};

    const CODE_LENS: [usize; 3] = [16, 192, 1024];
    const NUM_BATCHES: usize = 10;

    let mut rng = rand::rngs::StdRng::seed_from_u64(123);
    // 320 vectors = 10 batches
    let num_vectors = BATCH_SIZE * NUM_BATCHES;

    for code_len in CODE_LENS {
        // Keep table entries small enough that 2 * code_len lookups fit in u16.
        let max_val = (u16::MAX as usize / (2 * code_len)).min(255) as u8;

        let codes: Vec<u8> = std::iter::repeat_with(|| rng.random::<u8>())
            .take(num_vectors * code_len)
            .collect();
        let dist_table: Vec<u8> = std::iter::repeat_with(|| rng.random_range(0..=max_val))
            .take(BATCH_SIZE * code_len)
            .collect();

        let mut scalar_dists = vec![0u16; num_vectors];
        sum_4bit_dist_table_scalar(code_len, &codes, &dist_table, &mut scalar_dists);

        let mut simd_dists = vec![0u16; num_vectors];
        sum_4bit_dist_table(num_vectors, code_len, &codes, &dist_table, &mut simd_dists);

        assert_eq!(
            simd_dists,
            scalar_dists,
            "SIMD and scalar mismatch for multi-batch code_len={} (DIM={}, n={})",
            code_len,
            code_len * 8,
            num_vectors,
        );
    }
}
217384
}

0 commit comments

Comments
 (0)