
Commit 5c83b84

perf: add SIMD-accelerated u8 L2 and cosine distance kernels (#6517)
## Summary

- Add hand-written AVX2 and AVX-512 VNNI backends for u8 squared L2 distance (`Σ(a-b)²`) in new `l2_u8.rs`
- Add fused single-pass u8 cosine distance kernel in new `cosine_u8.rs` — computes `dot(a,b)`, `‖a‖²`, `‖b‖²` simultaneously, halving memory traffic vs the previous 2–3-pass approach
- Wire both into the `L2 for u8` and `Cosine for u8` trait impls
- Add benchmarks comparing scalar vs SIMD for both kernels

### Algorithmic approach (adapted from [NumKong](https://github.com/ashvardanian/NumKong))

**L2 (AVX2):** Saturating subtraction for `|a-b|`, zero-extend u8→i16, `VPMADDWD(diff, diff)` to square and accumulate into i32. 32 elements/iter.

**L2 (AVX-512 VNNI):** Same abs-diff approach with `VPDPWSSD` for fused square-accumulate. 64 elements/iter.

**Cosine (AVX2):** Zero-extend both vectors to i16, triple `VPMADDWD` per half (a·b, a·a, b·b). 32 elements/iter, single pass.

**Cosine (AVX-512 VNNI):** Same three-accumulator approach with `VPDPWSSD`. 64 elements/iter.

Both kernels use `OnceLock`-based runtime CPU dispatch, falling back to portable scalar on non-x86 platforms.

### Benchmarks

*1M × 1024-dim u8 vectors.*

**x86_64 — AMD Ryzen 5 4500 6-Core (AVX2, no AVX-512)**

| Kernel | Scalar | SIMD | Speedup |
|--------|--------|------|---------|
| L2(u8) | 73.5 ms | 58.2 ms | **1.26x** |
| Cosine(u8) | 122.2 ms | 82.1 ms | **1.49x** |

The L2 auto-vectorization baseline was 91.5 ms, so SIMD is 1.57x faster than that path.

**aarch64 — Apple Silicon M3 Max (no AVX2, scalar fallback)**

| Kernel | Scalar | SIMD (dispatch) |
|--------|--------|-----------------|
| L2(u8) | 26.8 ms | 27.3 ms |
| Cosine(u8) | 90.1 ms | 90.4 ms |

On aarch64 the SIMD path falls through to scalar (no AVX2), so times are identical — this confirms no regression on non-x86 platforms. AVX-512 VNNI systems (Ice Lake+, Zen 4+) should see larger gains.
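The fused single-pass cosine accumulation and the `OnceLock` dispatch described above can be sketched in portable Rust. This is a minimal illustration, not the crate's actual code: the names mirror the PR's `cosine_u8`/`cosine_u8_scalar`, but the bodies are assumptions, and the dispatch here simply falls back to the scalar kernel where a real build would install the AVX2 one.

```rust
use std::sync::OnceLock;

type CosineFn = fn(&[u8], &[u8]) -> f32;

/// Portable scalar fallback: fused single pass that accumulates
/// dot(a,b), ||a||^2 and ||b||^2 together, so each input byte is
/// loaded exactly once (u32 accumulators are safe up to ~66k dims).
fn cosine_u8_scalar(a: &[u8], b: &[u8]) -> f32 {
    assert_eq!(a.len(), b.len());
    let (mut dot, mut aa, mut bb) = (0u32, 0u32, 0u32);
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (x as u32, y as u32);
        dot += x * y;
        aa += x * x;
        bb += y * y;
    }
    if aa == 0 || bb == 0 {
        return 1.0; // convention for a zero vector
    }
    1.0 - dot as f32 / ((aa as f32).sqrt() * (bb as f32).sqrt())
}

/// OnceLock-based dispatch: probe CPU features once, then route every
/// call through the cached function pointer.
fn cosine_u8(a: &[u8], b: &[u8]) -> f32 {
    static KERNEL: OnceLock<CosineFn> = OnceLock::new();
    let f = KERNEL.get_or_init(|| {
        #[cfg(target_arch = "x86_64")]
        if is_x86_feature_detected!("avx2") {
            // a real build would return the AVX2 kernel here
            return cosine_u8_scalar;
        }
        cosine_u8_scalar
    });
    f(a, b)
}

fn main() {
    let a = vec![3u8; 1024];
    println!("{:.3}", cosine_u8(&a, &a)); // identical vectors -> 0.000
}
```

The three-accumulator loop is what the AVX2 kernel vectorizes with three `VPMADDWD` issues per half-register; the scalar form already shows why one pass halves memory traffic versus computing the norms separately.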
## Test plan

- [x] All 11 new tests pass: SIMD backends verified against scalar reference across 18 vector sizes (0–4097), boundary values (0/255), alternating patterns, random seeds
- [x] All 63 existing lance-linalg tests pass (no regressions)
- [x] Clippy clean, fmt clean
- [x] Benchmarked on x86_64 AVX2 (AMD Ryzen 5 4500) — L2 1.26x, Cosine 1.49x faster
- [ ] Verify on AVX-512 VNNI system for additional speedup data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
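The scalar-reference verification in the test plan rests on the saturating-subtraction identity the L2 kernels use. Here is a minimal sketch (hypothetical helper names, not the crate's API) cross-checking the branch-free formulation against a direct widened subtraction, including the 0/255 boundary values:

```rust
/// Direct scalar reference: widen to i32, subtract, square, accumulate.
fn l2_u8_reference(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let d = x as i32 - y as i32;
            (d * d) as u32
        })
        .sum()
}

/// Branch-free |a-b| via two saturating subtractions. This is the
/// per-lane identity the AVX2 kernel (VPSUBUSB, then VPMADDWD) relies
/// on: exactly one of the two saturating subtractions is nonzero, so
/// OR-ing them yields the absolute difference.
fn l2_u8_satsub(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let d = (x.saturating_sub(y) | y.saturating_sub(x)) as u32;
            d * d
        })
        .sum()
}

fn main() {
    // boundary values (0/255) and mixed magnitudes, as in the test plan
    let a = [0u8, 255, 7, 200];
    let b = [255u8, 0, 9, 100];
    assert_eq!(l2_u8_reference(&a, &b), l2_u8_satsub(&a, &b));
    println!("{}", l2_u8_satsub(&a, &b)); // prints 140054
}
```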
1 parent 65ac541 commit 5c83b84

7 files changed

Lines changed: 699 additions & 2 deletions


rust/lance-linalg/benches/cosine.rs

Lines changed: 37 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@ use arrow_array::{
 use criterion::{Criterion, black_box, criterion_group, criterion_main};
 use lance_arrow::{ArrowFloatType, FloatArray, bfloat16::BFloat16Type};
 use lance_linalg::distance::cosine::{Cosine, cosine_distance_batch};
+use lance_linalg::distance::cosine_u8::{cosine_u8, cosine_u8_scalar};
 use num_traits::Float;

 #[cfg(target_os = "linux")]
@@ -76,6 +77,42 @@ fn bench_distance(c: &mut Criterion) {
             black_box(cosine_distance_batch(key.values(), target.values(), 8).collect::<Vec<_>>())
         })
     });
+
+    // u8 cosine benchmarks
+    {
+        use rand::Rng;
+        use std::iter::repeat_with;
+
+        const DIMENSION: usize = 1024;
+        const TOTAL: usize = 1024 * 1024;
+        let mut rng = rand::rng();
+        let key_u8: Vec<u8> = repeat_with(|| rng.random()).take(DIMENSION).collect();
+        let target_u8: Vec<u8> = repeat_with(|| rng.random())
+            .take(TOTAL * DIMENSION)
+            .collect();
+
+        c.bench_function("Cosine(u8, scalar)", |b| {
+            b.iter(|| {
+                black_box(
+                    target_u8
+                        .chunks_exact(DIMENSION)
+                        .map(|tgt| cosine_u8_scalar(&key_u8, tgt))
+                        .fold(0.0, |acc: f32, v| acc + v),
+                );
+            });
+        });
+
+        c.bench_function("Cosine(u8, SIMD)", |b| {
+            b.iter(|| {
+                black_box(
+                    target_u8
+                        .chunks_exact(DIMENSION)
+                        .map(|tgt| cosine_u8(&key_u8, tgt))
+                        .fold(0.0, |acc: f32, v| acc + v),
+                );
+            });
+        });
+    }
 }

 #[cfg(target_os = "linux")]
```

rust/lance-linalg/benches/l2.rs

Lines changed: 12 additions & 0 deletions
```diff
@@ -15,6 +15,7 @@ use rand::Rng;
 use pprof::criterion::{Output, PProfProfiler};

 use lance_arrow::{ArrowFloatType, FloatArray};
+use lance_linalg::distance::l2_u8::l2_u8;
 use lance_linalg::distance::{L2, l2::l2, l2_distance_batch, l2_distance_uint_scalar};
 use lance_testing::datagen::generate_random_array_with_seed;

@@ -157,6 +158,17 @@ fn bench_uint_distance(c: &mut Criterion) {
            );
        });
    });
+
+    c.bench_function("L2(u8, SIMD)", |b| {
+        b.iter(|| {
+            black_box(
+                target
+                    .chunks_exact(DIMENSION)
+                    .map(|tgt| l2_u8(&key, tgt) as f32)
+                    .fold(0.0, |acc, v| acc + v),
+            );
+        });
+    });
 }

 #[cfg(target_os = "linux")]
```

rust/lance-linalg/src/distance.rs

Lines changed: 2 additions & 0 deletions
```diff
@@ -17,9 +17,11 @@ use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, L
 use arrow_schema::{ArrowError, DataType};

 pub mod cosine;
+pub mod cosine_u8;
 pub mod dot;
 pub mod hamming;
 pub mod l2;
+pub mod l2_u8;
 pub mod norm_l2;

 pub use cosine::*;
```

rust/lance-linalg/src/distance/cosine.rs

Lines changed: 6 additions & 1 deletion
```diff
@@ -65,7 +65,12 @@ pub trait Cosine: Dot + Normalize {
     }
 }

-impl Cosine for u8 {}
+impl Cosine for u8 {
+    #[inline]
+    fn cosine(x: &[Self], other: &[Self]) -> f32 {
+        super::cosine_u8::cosine_u8(x, other)
+    }
+}

 impl Cosine for bf16 {}
```