
Commit 8f479db

perf: add SIMD-accelerated u8 dot product for SQ distance (#6506)
## Summary

- Add AVX-512 VNNI and AVX2 backends for the unsigned int8 dot product, with runtime CPU feature detection and automatic fallback to scalar
- Replace the unoptimized `Dot<u8>` impl (which had an explicit TODO) with the dispatched SIMD kernel
- All existing callers, including SQ distance computation, benefit automatically with zero changes to lance-index

## Details

### New file: `rust/lance-linalg/src/distance/dot_u8.rs`

Three backends selected at runtime via `OnceLock` + `is_x86_feature_detected!`:

| Backend | Instruction | Elements/iter | CPU |
|---|---|---|---|
| AVX-512 VNNI | `VPDPBUSD` + XOR-0x80 bias trick | 64 | Ice Lake+ / Zen 4+ |
| AVX2 | `VPMADDWD` on zero-extended u16 | 32 | Haswell+ / Zen 1+ |
| Scalar | portable reference | - | any (including ARM) |

### The VNNI bias trick

`VPDPBUSD` expects one unsigned and one signed operand, but SQ vectors are u8×u8. We XOR one operand with 0x80 to map it into the signed domain, then correct by adding `128·Σa` at the end. The correction uses `VPSADBW`, which runs on execution port 5 while `VPDPBUSD` runs on port 0 — the two execute in parallel every cycle, making the correction effectively free.

### SQ integration (automatic)

`SQDistCalculator::distance()` already calls `dot_distance()` → `u8::dot()` for the Dot distance type, so replacing the `Dot<u8>` body is the only change needed.

## Benchmarks

### Ryzen 4500 (AVX2, no VNNI)

1M total u8 elements, varying vector dimension. Scalar baseline vs the AVX2-dispatched path:

| Dimension | Scalar | Dispatch (AVX2) | Speedup |
|-----------|--------|-----------------|---------|
| 128 | 51.02 µs | 58.25 µs | 0.88x (dispatch overhead dominates) |
| 256 | 44.96 µs | 38.62 µs | **1.16x** |
| 512 | 42.82 µs | 28.27 µs | **1.51x** |
| 1024 | 41.00 µs | 25.17 µs | **1.63x** |

AVX2 delivers up to 1.63x throughput at dim=1024. At dim=128 the `OnceLock` dispatch and AVX2 loop-setup overhead exceeds the SIMD gains on short vectors. AVX-512 VNNI (Ice Lake+ / Zen 4+) is expected to show larger gains at 64 elements/iter.

### Apple M4 (ARM64, scalar fallback)

On ARM64 the dispatch falls back to scalar, so both paths perform identically (~13 µs at dim=1024). A follow-up ARM NEON `UDOT` path would bring SIMD gains to Apple Silicon.

### Out of scope (follow-up)

- L2/Cosine u8 SIMD optimization (different kernel: `Σ(a-b)²`)
- Native `VPDPBUUD` (unsigned×unsigned, Sierra Forest+) — too new for stable Rust
- ARM NEON `UDOT` path
- Precomputed norms for SQ L2/Cosine (requires a storage format change)

## Test plan

- [x] Unit tests: random inputs across 18 vector sizes (0-4097), boundary values (all 0s, all 255s, alternating), one-sided zeros, all-ones patterns
- [x] Each backend tested independently against the scalar reference (with `#[cfg]` guards for missing CPU features)
- [x] Existing `dot` tests continue to pass (9/9)
- [x] `cargo clippy -p lance-linalg --tests --benches -- -D warnings` clean
- [x] Benchmark on x86_64 with AVX2: `cargo bench --bench dot -p lance-linalg -- "Dot\(u8"`

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
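As a sanity check on the bias-trick algebra, here is a small scalar model (illustrative only, not part of this commit; `biased_dot_model` is a made-up name) of what `VPDPBUSD` computes on the biased operand, verifying that adding `128·Σa` recovers the exact u8×u8 dot product:

```rust
/// Scalar model of the VNNI bias trick; the real kernel lives in dot_u8.rs.
fn biased_dot_model(a: &[u8], b: &[u8]) -> u32 {
    // What VPDPBUSD sees: `a` as unsigned, and `b ^ 0x80` reinterpreted
    // as i8, which equals (b as i64) - 128 for every byte value.
    let biased: i64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| x as i64 * ((y ^ 0x80) as i8 as i64))
        .sum();
    // DPBUSD(a, b ^ 0x80) = Σ a·b − 128·Σa, so add the correction back.
    let sum_a: i64 = a.iter().map(|&x| x as i64).sum();
    (biased + 128 * sum_a) as u32
}

fn main() {
    let a = [3u8, 200, 255, 0, 17];
    let b = [250u8, 4, 255, 9, 128];
    let reference: u32 = a.iter().zip(&b).map(|(&x, &y)| x as u32 * y as u32).sum();
    assert_eq!(biased_dot_model(&a, &b), reference);
}
```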
1 parent c913ff8 commit 8f479db

4 files changed

Lines changed: 291 additions & 5 deletions


rust/lance-linalg/benches/dot.rs

Lines changed: 34 additions & 0 deletions
```diff
@@ -107,6 +107,40 @@ fn bench_distance(c: &mut Criterion) {
         });
     });
 
+    // u8 dot product benchmarks: scalar baseline vs SIMD dispatch
+    {
+        use lance_linalg::distance::dot_u8::{dot_u8, dot_u8_scalar};
+
+        for &dim in &[128, 256, 512, 1024] {
+            let num_vectors = 1024 * 1024 / dim; // ~1M elements total
+            let mut rng = rand::rng();
+            let key_u8: Vec<u8> = (0..dim).map(|_| rng.random()).collect();
+            let target_u8: Vec<u8> = (0..num_vectors * dim).map(|_| rng.random()).collect();
+
+            c.bench_function(&format!("Dot(u8, scalar, dim={dim})"), |b| {
+                b.iter(|| {
+                    black_box(
+                        target_u8
+                            .chunks(dim)
+                            .map(|y| dot_u8_scalar(key_u8.as_slice(), y))
+                            .collect::<Vec<_>>(),
+                    )
+                });
+            });
+
+            c.bench_function(&format!("Dot(u8, dispatch, dim={dim})"), |b| {
+                b.iter(|| {
+                    black_box(
+                        target_u8
+                            .chunks(dim)
+                            .map(|y| dot_u8(key_u8.as_slice(), y))
+                            .collect::<Vec<_>>(),
+                    )
+                });
+            });
+        }
+    }
+
     run_bench::<Float32Type>(c);
     c.bench_function("Dot(f32, SIMD)", |b| {
         let key = generate_random_array_with_seed::<Float32Type>(DIMENSION, [0; 32]);
```

rust/lance-linalg/src/distance.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,6 +19,7 @@ use arrow_schema::{ArrowError, DataType};
 pub mod cosine;
 pub mod cosine_u8;
 pub mod dot;
+pub mod dot_u8;
 pub mod hamming;
 pub mod l2;
 pub mod l2_u8;
```

rust/lance-linalg/src/distance/dot.rs

Lines changed: 1 addition & 5 deletions
```diff
@@ -189,11 +189,7 @@ fn dot_f64_simd(x: &[f64], y: &[f64]) -> f32 {
 impl Dot for u8 {
     #[inline]
     fn dot(x: &[Self], y: &[Self]) -> f32 {
-        // TODO: this is not optimized for auto vectorization yet.
-        x.iter()
-            .zip(y.iter())
-            .map(|(&x_i, &y_i)| x_i as u32 * y_i as u32)
-            .sum::<u32>() as f32
+        super::dot_u8::dot_u8(x, y) as f32
     }
 }
 
```
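For reference, a minimal caller-side sketch showing that anything going through the `Dot` trait now picks up the dispatched kernel. The import path is assumed from the module layout above, and the vectors are made-up values:

```rust
use lance_linalg::distance::dot::Dot; // path assumed from `pub mod dot` above

fn main() {
    // Two SQ-encoded vectors. `u8::dot` now routes through dot_u8::dot_u8,
    // which selects AVX-512 VNNI, AVX2, or scalar on first call.
    let q: Vec<u8> = vec![0, 128, 255, 64];
    let v: Vec<u8> = vec![255, 128, 0, 64];
    let d: f32 = u8::dot(&q, &v);
    assert_eq!(d, (128 * 128 + 64 * 64) as f32);
}
```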

rust/lance-linalg/src/distance/dot_u8.rs (new file)

Lines changed: 255 additions & 0 deletions

```rust
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Unsigned int8 dot product with runtime-dispatched SIMD backends.
//!
//! Used by Scalar Quantization (SQ) distance computation. SQ stores each
//! vector dimension as a u8 after linearly mapping [min, max] → [0, 255].
//! Distance computation between SQ-encoded vectors reduces to a u8 × u8
//! dot product plus precomputed per-vector scalar terms.
//!
//! Backends (selected at runtime, best available wins):
//! 1. scalar — portable reference, also used for tails
//! 2. avx2 — VPMADDWD on u16-widened halves, 32 elements/iter
//! 3. avx512vnni — VPDPBUSD with XOR-0x80 bias trick, 64 elements/iter
//!
//! ## The VNNI bias trick
//!
//! VPDPBUSD expects one unsigned and one signed operand, but SQ vectors
//! are u8 × u8. We bias `b` into the signed domain via XOR 0x80 (equivalent
//! to subtracting 128 when reinterpreted as i8), feed `a` directly as
//! unsigned, and correct by adding 128·Σa at the end:
//!
//! DPBUSD(a, b ⊕ 0x80) = Σ a·(b − 128) = Σ a·b − 128·Σa
//!
//! The Σa term uses VPSADBW, which dispatches to port 5 while VPDPBUSD
//! runs on port 0 on Intel. The two instructions execute in parallel,
//! making the correction effectively free.

use std::sync::OnceLock;

/// Portable scalar u8 dot product, also used for SIMD tail elements.
#[inline]
pub fn dot_u8_scalar(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| x as u32 * y as u32)
        .sum()
}

#[cfg(target_arch = "x86_64")]
mod x86 {
    use std::arch::x86_64::*;

    /// AVX2 path: zero-extend u8→u16, then VPMADDWD. 32 elements/iter.
    #[target_feature(enable = "avx2")]
    pub unsafe fn dot_u8_avx2(a: &[u8], b: &[u8]) -> u32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();
        let mut acc = _mm256_setzero_si256();
        let mut i = 0usize;

        while i + 32 <= n {
            let av = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
            let bv = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i);

            // Zero-extend each 128-bit half to 16 × u16. Values ≤ 255 fit
            // in i16 as positive, so VPMADDWD gives correct results.
            let a_lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(av));
            let a_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(av, 1));
            let b_lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(bv));
            let b_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(bv, 1));

            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a_lo, b_lo));
            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a_hi, b_hi));
            i += 32;
        }

        let lo128 = _mm256_castsi256_si128(acc);
        let hi128 = _mm256_extracti128_si256(acc, 1);
        let mut sum128 = _mm_add_epi32(lo128, hi128);
        sum128 = _mm_hadd_epi32(sum128, sum128);
        sum128 = _mm_hadd_epi32(sum128, sum128);
        let mut result = _mm_cvtsi128_si32(sum128) as u32;

        while i < n {
            result += a[i] as u32 * b[i] as u32;
            i += 1;
        }
        result
    }

    /// AVX-512 VNNI path (Ice Lake+, Zen 4+). 64 elements/iter.
    ///
    /// VPDPBUSD expects (unsigned, signed) operands but SQ stores u8×u8.
    /// We XOR b with 0x80 to map it to i8, then correct: result + 128·Σa.
    /// The Σa term (VPSADBW, port 5) runs in parallel with VPDPBUSD (port 0).
    #[target_feature(enable = "avx512f,avx512bw,avx512vnni")]
    pub unsafe fn dot_u8_avx512_vnni(a: &[u8], b: &[u8]) -> u32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut acc_dot = _mm512_setzero_si512();
        let mut acc_suma = _mm512_setzero_si512();
        let sign_flip = _mm512_set1_epi8(0x80u8 as i8);
        let zeros = _mm512_setzero_si512();
        let mut i = 0usize;

        while i + 64 <= n {
            let av = _mm512_loadu_si512(a.as_ptr().add(i) as *const __m512i);
            let bv = _mm512_loadu_si512(b.as_ptr().add(i) as *const __m512i);
            let b_biased = _mm512_xor_si512(bv, sign_flip);
            acc_dot = _mm512_dpbusd_epi32(acc_dot, av, b_biased);
            acc_suma = _mm512_add_epi64(acc_suma, _mm512_sad_epu8(av, zeros));
            i += 64;
        }

        let biased_dot = _mm512_reduce_add_epi32(acc_dot);
        let sum_a = _mm512_reduce_add_epi64(acc_suma);
        let mut result = (biased_dot as i64 + 128 * sum_a) as u32;

        while i < n {
            result += a[i] as u32 * b[i] as u32;
            i += 1;
        }
        result
    }
}

type DotU8Fn = fn(&[u8], &[u8]) -> u32;

static DISPATCH: OnceLock<DotU8Fn> = OnceLock::new();

fn select_backend() -> DotU8Fn {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f")
            && is_x86_feature_detected!("avx512bw")
            && is_x86_feature_detected!("avx512vnni")
        {
            return |a, b| unsafe { x86::dot_u8_avx512_vnni(a, b) };
        }

        if is_x86_feature_detected!("avx2") {
            return |a, b| unsafe { x86::dot_u8_avx2(a, b) };
        }
    }

    dot_u8_scalar
}

/// Dispatched u8 dot product, selecting the best available SIMD backend.
#[inline]
pub fn dot_u8(a: &[u8], b: &[u8]) -> u32 {
    (DISPATCH.get_or_init(select_backend))(a, b)
}

#[cfg(test)]
mod tests {
    use super::*;

    fn fill_random(buf: &mut [u8], seed: &mut u32) {
        for slot in buf.iter_mut() {
            *seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            *slot = (*seed >> 16) as u8;
        }
    }

    const SIZES: &[usize] = &[
        0, 1, 7, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 1024, 4096, 4097,
    ];

    fn check_all_backends(a: &[u8], b: &[u8], case: &str) {
        let reference = dot_u8_scalar(a, b);

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                let got = unsafe { x86::dot_u8_avx2(a, b) };
                assert_eq!(got, reference, "avx2 [{case}] n={}", a.len());
            }

            if is_x86_feature_detected!("avx512f")
                && is_x86_feature_detected!("avx512bw")
                && is_x86_feature_detected!("avx512vnni")
            {
                let got = unsafe { x86::dot_u8_avx512_vnni(a, b) };
                assert_eq!(got, reference, "avx512_vnni [{case}] n={}", a.len());
            }
        }

        assert_eq!(dot_u8(a, b), reference, "dispatch [{case}] n={}", a.len());
    }

    #[test]
    fn random_inputs_across_sizes_and_seeds() {
        let mut a = vec![0u8; 4097];
        let mut b = vec![0u8; 4097];

        for seed_idx in 0..4u32 {
            let mut seed = 0xC0FFEE_u32.wrapping_add(seed_idx.wrapping_mul(7919));
            for &n in SIZES {
                fill_random(&mut a[..n], &mut seed);
                fill_random(&mut b[..n], &mut seed);
                check_all_backends(&a[..n], &b[..n], "random");
            }
        }
    }

    #[test]
    fn boundary_values() {
        let mut a = vec![0u8; 4097];
        let mut b = vec![0u8; 4097];

        for &n in SIZES {
            a[..n].fill(u8::MAX);
            b[..n].fill(u8::MAX);
            check_all_backends(&a[..n], &b[..n], "max*max");

            a[..n].fill(u8::MAX);
            b[..n].fill(0);
            check_all_backends(&a[..n], &b[..n], "max*0");

            a[..n].fill(0);
            b[..n].fill(u8::MAX);
            check_all_backends(&a[..n], &b[..n], "0*max");

            for i in 0..n {
                a[i] = if i & 1 == 0 { 0 } else { u8::MAX };
                b[i] = if i & 1 == 0 { u8::MAX } else { 0 };
            }
            check_all_backends(&a[..n], &b[..n], "alt 0/max");
        }
    }

    #[test]
    fn one_sided_zeros() {
        let mut a = vec![0u8; 4097];
        let mut b = vec![0u8; 4097];

        for &n in SIZES {
            let mut seed = 0xDEAD_BEEF_u32;
            fill_random(&mut a[..n], &mut seed);
            b[..n].fill(0);
            check_all_backends(&a[..n], &b[..n], "b=0");

            a[..n].fill(0);
            fill_random(&mut b[..n], &mut seed);
            check_all_backends(&a[..n], &b[..n], "a=0");
        }
    }

    #[test]
    fn all_ones_pattern() {
        let mut a = vec![0u8; 4097];
        let mut b = vec![0u8; 4097];

        for &n in SIZES {
            a[..n].fill(1);
            b[..n].fill(1);
            check_all_backends(&a[..n], &b[..n], "1*1");
            assert_eq!(dot_u8_scalar(&a[..n], &b[..n]), n as u32);
        }
    }
}
```
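For context on where these u8 vectors come from: a toy sketch of the `[min, max] → [0, 255]` mapping the module doc describes. This is illustrative only; `sq_encode` is a hypothetical helper, not Lance's actual SQ codec (which lives in lance-index and may differ in rounding details):

```rust
/// Hypothetical SQ encode step: linearly map [min, max] → [0, 255] per
/// dimension, as described in the module doc. Not the actual Lance codec.
fn sq_encode(v: &[f32], min: f32, max: f32) -> Vec<u8> {
    let scale = 255.0 / (max - min);
    v.iter()
        .map(|&x| ((x - min) * scale).round().clamp(0.0, 255.0) as u8)
        .collect()
}

fn main() {
    let encoded = sq_encode(&[0.0, 0.5, 1.0], 0.0, 1.0);
    assert_eq!(encoded, vec![0, 128, 255]);
}
```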
