Skip to content

Commit c913ff8

Browse files
perf: add explicit SIMD types and distance kernels for f64 (#6540)
## Summary - Adds `f64x4` and `f64x8` SIMD types to `lance-linalg` with support for x86_64 (AVX2/AVX-512), aarch64 (NEON), and loongarch64 (LASX) - Replaces auto-vectorization-dependent f64 distance functions with explicit SIMD using two-level unrolling (f64x8 + f64x4 + scalar tail) - Updates norm_l2, dot, L2, and cosine distance for f64 ## Benchmark Results (Apple M-series, aarch64 NEON) 1M vectors × 1024 dimensions: | Benchmark | Before | After | Change | |-----------|--------|-------|--------| | NormL2(f64, auto-vec) | 117.76 ms | 116.04 ms | ~same | | NormL2(f64, SIMD) | N/A (TODO) | 119.16 ms | new | | Dot(f64, auto-vec) | 129.36 ms | 130.23 ms | ~same | | L2(f64, auto-vec) | 132.53 ms | 135.15 ms | ~same | | **Cosine(f64, auto-vec)** | **202.52 ms** | **139.23 ms** | **-31.4%** | The biggest win is **cosine distance**, which previously had an empty `impl Cosine for f64 {}` falling back to the scalar path. The explicit SIMD implementation is **31% faster**. For norm_l2, dot, and L2, LLVM's auto-vectorization with the LANES=8 hint was already producing good code on this platform. The explicit SIMD ensures consistent performance across compilers and platforms rather than relying on fragile auto-vectorization hints. ## Test plan - [x] All 59 lance-linalg tests pass - [x] Clippy clean (`-D warnings`) - [x] `cargo fmt` clean - [ ] CI passes on all platforms 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent accef66 commit c913ff8

7 files changed

Lines changed: 1044 additions & 6 deletions

File tree

rust/lance-linalg/benches/norm_l2.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use num_traits::Float;
1313
use rand::Rng;
1414

1515
use lance_arrow::{ArrowFloatType, FloatArray, bfloat16::BFloat16Type};
16-
use lance_linalg::distance::{norm_l2, norm_l2_impl};
16+
use lance_linalg::distance::{norm_l2, norm_l2_f64_simd, norm_l2_impl};
1717
use lance_testing::datagen::generate_random_array_with_seed;
1818

1919
#[cfg(target_os = "linux")]
@@ -106,7 +106,7 @@ fn bench_distance(c: &mut Criterion) {
106106
c,
107107
target.as_slice(),
108108
norm_l2_impl::<f64, f32, 8>,
109-
None, // TODO: implement SIMD for f64
109+
Some(norm_l2_f64_simd),
110110
);
111111
}
112112

rust/lance-linalg/src/distance/cosine.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,47 @@ impl Cosine for f32 {
230230
}
231231
}
232232

233-
impl Cosine for f64 {}
233+
impl Cosine for f64 {
234+
#[inline]
235+
fn cosine_fast(x: &[Self], x_norm: f32, y: &[Self]) -> f32 {
236+
use crate::simd::f64::{f64x4, f64x8};
237+
use crate::simd::{FloatSimd, SIMD};
238+
239+
let dim = x.len();
240+
let unrolled_len = dim / 8 * 8;
241+
let mut y_norm8 = f64x8::zeros();
242+
let mut xy8 = f64x8::zeros();
243+
for i in (0..unrolled_len).step_by(8) {
244+
unsafe {
245+
let xv = f64x8::load_unaligned(x.as_ptr().add(i));
246+
let yv = f64x8::load_unaligned(y.as_ptr().add(i));
247+
xy8.multiply_add(xv, yv);
248+
y_norm8.multiply_add(yv, yv);
249+
}
250+
}
251+
let aligned_len = dim / 4 * 4;
252+
let mut y_norm4 = f64x4::zeros();
253+
let mut xy4 = f64x4::zeros();
254+
for i in (unrolled_len..aligned_len).step_by(4) {
255+
unsafe {
256+
let xv = f64x4::load_unaligned(x.as_ptr().add(i));
257+
let yv = f64x4::load_unaligned(y.as_ptr().add(i));
258+
xy4.multiply_add(xv, yv);
259+
y_norm4.multiply_add(yv, yv);
260+
}
261+
}
262+
let tail_y_norm: Self = y[aligned_len..].iter().map(|&v| v * v).sum();
263+
let tail_xy: Self = x[aligned_len..]
264+
.iter()
265+
.zip(y[aligned_len..].iter())
266+
.map(|(&a, &b)| a * b)
267+
.sum();
268+
269+
let y_norm_sq = (y_norm8.reduce_sum() + y_norm4.reduce_sum() + tail_y_norm) as f32;
270+
let xy = (xy8.reduce_sum() + xy4.reduce_sum() + tail_xy) as f32;
271+
1.0 - xy / x_norm / y_norm_sq.sqrt()
272+
}
273+
}
234274

235275
/// Fallback non-SIMD implementation
236276
#[inline]

rust/lance-linalg/src/distance/dot.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,47 @@ impl Dot for f32 {
145145
impl Dot for f64 {
146146
#[inline]
147147
fn dot(x: &[Self], y: &[Self]) -> f32 {
148-
dot_scalar::<Self, Self, 8>(x, y) as f32
148+
dot_f64_simd(x, y)
149149
}
150150
}
151151

152+
/// Explicit SIMD dot product for f64.
153+
#[inline]
154+
fn dot_f64_simd(x: &[f64], y: &[f64]) -> f32 {
155+
use crate::simd::f64::{f64x4, f64x8};
156+
use crate::simd::{FloatSimd, SIMD};
157+
158+
let dim = x.len();
159+
let unrolled_len = dim / 8 * 8;
160+
161+
let mut acc8 = f64x8::zeros();
162+
for i in (0..unrolled_len).step_by(8) {
163+
unsafe {
164+
let a = f64x8::load_unaligned(x.as_ptr().add(i));
165+
let b = f64x8::load_unaligned(y.as_ptr().add(i));
166+
acc8.multiply_add(a, b);
167+
}
168+
}
169+
170+
let aligned_len = dim / 4 * 4;
171+
let mut acc4 = f64x4::zeros();
172+
for i in (unrolled_len..aligned_len).step_by(4) {
173+
unsafe {
174+
let a = f64x4::load_unaligned(x.as_ptr().add(i));
175+
let b = f64x4::load_unaligned(y.as_ptr().add(i));
176+
acc4.multiply_add(a, b);
177+
}
178+
}
179+
180+
let tail: f64 = x[aligned_len..]
181+
.iter()
182+
.zip(y[aligned_len..].iter())
183+
.map(|(&a, &b)| a * b)
184+
.sum();
185+
186+
(acc8.reduce_sum() + acc4.reduce_sum() + tail) as f32
187+
}
188+
152189
impl Dot for u8 {
153190
#[inline]
154191
fn dot(x: &[Self], y: &[Self]) -> f32 {

rust/lance-linalg/src/distance/l2.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,52 @@ impl L2 for f32 {
170170
impl L2 for f64 {
171171
#[inline]
172172
fn l2(x: &[Self], y: &[Self]) -> f32 {
173-
l2_scalar::<Self, Self, 8>(x, y) as f32
173+
l2_f64_simd(x, y)
174174
}
175175
}
176176

177+
/// Explicit SIMD L2 distance for f64.
178+
#[inline]
179+
fn l2_f64_simd(x: &[f64], y: &[f64]) -> f32 {
180+
use crate::simd::f64::{f64x4, f64x8};
181+
use crate::simd::{FloatSimd, SIMD};
182+
183+
let dim = x.len();
184+
let unrolled_len = dim / 8 * 8;
185+
186+
let mut acc8 = f64x8::zeros();
187+
for i in (0..unrolled_len).step_by(8) {
188+
unsafe {
189+
let a = f64x8::load_unaligned(x.as_ptr().add(i));
190+
let b = f64x8::load_unaligned(y.as_ptr().add(i));
191+
let diff = a - b;
192+
acc8.multiply_add(diff, diff);
193+
}
194+
}
195+
196+
let aligned_len = dim / 4 * 4;
197+
let mut acc4 = f64x4::zeros();
198+
for i in (unrolled_len..aligned_len).step_by(4) {
199+
unsafe {
200+
let a = f64x4::load_unaligned(x.as_ptr().add(i));
201+
let b = f64x4::load_unaligned(y.as_ptr().add(i));
202+
let diff = a - b;
203+
acc4.multiply_add(diff, diff);
204+
}
205+
}
206+
207+
let tail: f64 = x[aligned_len..]
208+
.iter()
209+
.zip(y[aligned_len..].iter())
210+
.map(|(&a, &b)| {
211+
let diff = a - b;
212+
diff * diff
213+
})
214+
.sum();
215+
216+
(acc8.reduce_sum() + acc4.reduce_sum() + tail) as f32
217+
}
218+
177219
/// Accumulate squared differences for one dimension into per-target results.
178220
///
179221
/// Separated into its own function so that LLVM sees `row` and `result`

rust/lance-linalg/src/distance/norm_l2.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,42 @@ impl Normalize for f32 {
9797
impl Normalize for f64 {
9898
#[inline]
9999
fn norm_l2(vector: &[Self]) -> f32 {
100-
norm_l2_impl::<Self, Self, 8>(vector) as f32
100+
norm_l2_f64_simd(vector)
101101
}
102102
}
103103

104+
/// Explicit SIMD implementation of L2 norm for f64.
105+
///
106+
/// Two-level unrolling: f64x8 main loop, f64x4 remainder, scalar tail.
107+
#[inline]
108+
pub fn norm_l2_f64_simd(vector: &[f64]) -> f32 {
109+
use crate::simd::f64::{f64x4, f64x8};
110+
use crate::simd::{FloatSimd, SIMD};
111+
112+
let dim = vector.len();
113+
let unrolled_len = dim / 8 * 8;
114+
115+
let mut acc8 = f64x8::zeros();
116+
for i in (0..unrolled_len).step_by(8) {
117+
unsafe {
118+
let v = f64x8::load_unaligned(vector.as_ptr().add(i));
119+
acc8.multiply_add(v, v);
120+
}
121+
}
122+
123+
let aligned_len = dim / 4 * 4;
124+
let mut acc4 = f64x4::zeros();
125+
for i in (unrolled_len..aligned_len).step_by(4) {
126+
unsafe {
127+
let v = f64x4::load_unaligned(vector.as_ptr().add(i));
128+
acc4.multiply_add(v, v);
129+
}
130+
}
131+
132+
let tail: f64 = vector[aligned_len..].iter().map(|&v| v * v).sum();
133+
(acc8.reduce_sum() + acc4.reduce_sum() + tail).sqrt() as f32
134+
}
135+
104136
/// NOTE: this is only pub for benchmarking purposes
105137
#[inline]
106138
pub fn norm_l2_impl<

rust/lance-linalg/src/simd.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use std::ops::{Add, AddAssign, Mul, Sub, SubAssign};
1616

1717
pub mod dist_table;
1818
pub mod f32;
19+
pub mod f64;
1920
pub mod i32;
2021
pub mod u8;
2122

0 commit comments

Comments
 (0)