Skip to content

Commit d0124ed

Browse files
perf: add SIMD kernels for bf16 distance functions (#6510)
## Summary - Replaces the external `numkong` dependency with in-tree C kernels for **bf16 distance computation** (dot product, L2, cosine, norm_l2) - Follows the existing f16 kernel pattern: C source compiled via `build.rs` with per-architecture flags, runtime CPU dispatch via `SIMD_SUPPORT` - Kernels are only enabled when the CPU supports the required instructions (NEON on aarch64, AVX2/AVX-512 on x86_64, LSX/LASX on loongarch64), with scalar fallback otherwise - Gated behind the existing `fp16kernels` feature flag ## Benchmark Results Tested on two platforms with 1M x 1024-dim vectors: ### Apple Silicon (M-series, NEON) | Benchmark | Before (scalar) | After (C kernel) | Change | |-----------|-----------------|-------------------|--------| | **Dot(bf16)** | 144 ms | 55 ms | **2.6x faster** | | **NormL2(bf16)** | 90 ms | 36 ms | **2.5x faster** | ### AMD Ryzen 5 4500 (Zen 2, AVX2) | Benchmark | Before (scalar) | After (C kernel) | Change | |-----------|-----------------|-------------------|--------| | **Dot(bf16)** | 578 ms | 363 ms | **1.6x faster** (−37%) | | **NormL2(bf16)** | 365 ms | 207 ms | **1.8x faster** (−43%) | ### Why the approach works BF16-to-f32 conversion is a simple left-shift by 16 bits. The C kernels compiled with architecture-specific flags (`-march=haswell`, `-mtune=apple-m1`, etc.) plus `-ffast-math` and vectorization pragmas give the compiler more freedom to emit tight SIMD code than LLVM gets from the Rust scalar loops. ARM benefits more because the baseline Rust auto-vectorization was weaker there. ## Files Changed - **New**: `rust/lance-linalg/src/simd/bf16.c` — C kernels for dot, L2, cosine, norm_l2 - `rust/lance-linalg/build.rs` — compile bf16.c for each architecture - `rust/lance-linalg/src/distance/{dot,l2,cosine,norm_l2}.rs` — runtime SIMD dispatch for bf16 - `rust/lance-linalg/Cargo.toml` — removed `numkong` dependency and feature - `rust/lance-linalg/benches/{dot,l2,cosine}.rs` — removed numkong benchmark sections - **Deleted**: `scripts/bench_numkong.sh` ## Test plan - [x] `cargo test -p lance-linalg --features fp16kernels` — all bf16 tests pass (kernel path) - [x] `cargo test -p lance-linalg` — all bf16 tests pass (scalar fallback) - [x] `cargo clippy -p lance-linalg --features fp16kernels --tests --benches -- -D warnings` — clean - [x] Benchmarked on Apple Silicon (ARM NEON) - [x] Benchmarked on AMD Ryzen 5 4500 (x86_64 AVX2) - To reproduce: ```bash git checkout HEAD~1 TARGET_TIME=3 cargo bench -p lance-linalg --features fp16kernels --bench dot -- --save-baseline before "bf16" TARGET_TIME=3 cargo bench -p lance-linalg --features fp16kernels --bench norm_l2 -- --save-baseline before "bf16" git checkout - TARGET_TIME=3 cargo bench -p lance-linalg --features fp16kernels --bench dot -- --baseline before "bf16" TARGET_TIME=3 cargo bench -p lance-linalg --features fp16kernels --bench norm_l2 -- --baseline before "bf16" ``` 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8f479db commit d0124ed

7 files changed

Lines changed: 322 additions & 6 deletions

File tree

rust/lance-linalg/build.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ fn main() -> Result<(), String> {
1919
println!("cargo::rustc-check-cfg=cfg(kernel_support, values(\"avx512\"))");
2020

2121
println!("cargo:rerun-if-changed=src/simd/f16.c");
22+
println!("cargo:rerun-if-changed=src/simd/bf16.c");
2223
println!("cargo:rerun-if-changed=src/simd/dist_table.c");
2324

2425
// Important: we don't use `cfg!(target_arch)` here because that is the target_arch
@@ -37,13 +38,16 @@ fn main() -> Result<(), String> {
3738
if target_arch == "aarch64" && target_os == "macos" {
3839
// Build a version with NEON
3940
build_f16_with_flags("neon", &["-mtune=apple-m1"]).unwrap();
41+
build_bf16_with_flags("neon", &["-mtune=apple-m1"]).unwrap();
4042
} else if target_arch == "aarch64" && target_os == "ios" {
4143
// Build version with NEON
4244
// A13 bionic is the earliest supported iOS SOC
4345
build_f16_with_flags("neon", &["-mtune=apple-a13"]).unwrap();
46+
build_bf16_with_flags("neon", &["-mtune=apple-a13"]).unwrap();
4447
} else if target_arch == "aarch64" && (target_os == "linux" || target_os == "android") {
4548
// Build a version with NEON
4649
build_f16_with_flags("neon", &["-march=armv8.2-a+fp16"]).unwrap();
50+
build_bf16_with_flags("neon", &["-march=armv8.2-a+fp16"]).unwrap();
4751
} else if target_arch == "x86_64" {
4852
// Build a version with AVX512
4953
if let Err(err) = build_f16_with_flags("avx512", &["-march=sapphirerapids", "-mavx512fp16"])
@@ -59,6 +63,17 @@ fn main() -> Result<(), String> {
5963
// generated the AVX512 version of the f16 kernels.
6064
println!("cargo:rustc-cfg=kernel_support=\"avx512\"");
6165
};
66+
// Build AVX-512 bf16 kernels (sapphirerapids has native vdpbf16ps)
67+
if let Err(err) =
68+
build_bf16_with_flags("avx512", &["-march=sapphirerapids", "-mavx512fp16"])
69+
{
70+
println!(
71+
"cargo:warning=Skipping build of AVX-512 bf16 kernels. Error: {}",
72+
err
73+
);
74+
} else {
75+
println!("cargo:rustc-cfg=kernel_support=\"avx512\"");
76+
};
6277
if let Err(err) = build_dist_table_with_flags("avx512", &["-march=native"]) {
6378
println!(
6479
"cargo:warning=Skipping build of AVX-512 dist_table. Error: {}",
@@ -77,11 +92,20 @@ fn main() -> Result<(), String> {
7792
err
7893
));
7994
};
95+
// Build AVX2 bf16 kernels (bf16-to-f32 is just a shift, auto-vectorizes well)
96+
if let Err(err) = build_bf16_with_flags("avx2", &["-march=haswell"]) {
97+
return Err(format!(
98+
"Unable to build AVX2 bf16 kernels. Received error: {}",
99+
err
100+
));
101+
};
80102
// There is no SSE instruction set for f16 -> f32 float conversion
81103
} else if target_arch == "loongarch64" {
82104
// Build a version with LSX and LASX
83105
build_f16_with_flags("lsx", &["-mlsx"]).unwrap();
84106
build_f16_with_flags("lasx", &["-mlasx"]).unwrap();
107+
build_bf16_with_flags("lsx", &["-mlsx"]).unwrap();
108+
build_bf16_with_flags("lasx", &["-mlasx"]).unwrap();
85109
} else {
86110
// Only error if fp16kernels was explicitly requested on unsupported platform.
87111
// This allows builds on iOS, Android, etc. when the feature is disabled.
@@ -128,6 +152,32 @@ fn build_f16_with_flags(suffix: &str, flags: &[&str]) -> Result<(), cc::Error> {
128152
builder.try_compile(&format!("f16_{}", suffix))
129153
}
130154

155+
fn build_bf16_with_flags(suffix: &str, flags: &[&str]) -> Result<(), cc::Error> {
156+
if cfg!(not(feature = "fp16kernels")) {
157+
println!(
158+
"cargo:warning=fp16kernels feature is not enabled, skipping build of bf16 kernels"
159+
);
160+
return Ok(());
161+
}
162+
163+
let mut builder = cc::Build::new();
164+
builder
165+
.std("c17")
166+
.file("src/simd/bf16.c")
167+
.flag("-ffast-math")
168+
.flag("-funroll-loops")
169+
.flag("-O3")
170+
.flag("-Wall")
171+
.flag("-Wextra")
172+
.flag(format!("-DSUFFIX=_{}", suffix).as_str());
173+
174+
for flag in flags {
175+
builder.flag(flag);
176+
}
177+
178+
builder.try_compile(&format!("bf16_{}", suffix))
179+
}
180+
131181
fn build_dist_table_with_flags(suffix: &str, flags: &[&str]) -> Result<(), cc::Error> {
132182
let mut builder = cc::Build::new();
133183
builder

rust/lance-linalg/src/distance/cosine.rs

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,65 @@ impl Cosine for u8 {
7272
}
7373
}
7474

75-
impl Cosine for bf16 {}
75+
#[cfg(feature = "fp16kernels")]
76+
mod bf16_kernel {
77+
use half::bf16;
78+
79+
// These are the `cosine_bf16` function in bf16.c. Our build.rs script compiles
80+
// a version of this file for each SIMD level with different suffixes.
81+
unsafe extern "C" {
82+
#[cfg(target_arch = "aarch64")]
83+
pub fn cosine_bf16_neon(x: *const bf16, x_norm: f32, y: *const bf16, dimension: u32)
84+
-> f32;
85+
#[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))]
86+
pub fn cosine_bf16_avx512(
87+
x: *const bf16,
88+
x_norm: f32,
89+
y: *const bf16,
90+
dimension: u32,
91+
) -> f32;
92+
#[cfg(target_arch = "x86_64")]
93+
pub fn cosine_bf16_avx2(x: *const bf16, x_norm: f32, y: *const bf16, dimension: u32)
94+
-> f32;
95+
#[cfg(target_arch = "loongarch64")]
96+
pub fn cosine_bf16_lsx(x: *const bf16, x_norm: f32, y: *const bf16, dimension: u32) -> f32;
97+
#[cfg(target_arch = "loongarch64")]
98+
pub fn cosine_bf16_lasx(x: *const bf16, x_norm: f32, y: *const bf16, dimension: u32)
99+
-> f32;
100+
}
101+
}
102+
103+
impl Cosine for bf16 {
104+
fn cosine_fast(x: &[Self], x_norm: f32, y: &[Self]) -> f32 {
105+
match *SIMD_SUPPORT {
106+
#[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))]
107+
SimdSupport::Neon => unsafe {
108+
bf16_kernel::cosine_bf16_neon(x.as_ptr(), x_norm, y.as_ptr(), y.len() as u32)
109+
},
110+
#[cfg(all(
111+
feature = "fp16kernels",
112+
kernel_support = "avx512",
113+
target_arch = "x86_64"
114+
))]
115+
SimdSupport::Avx512FP16 => unsafe {
116+
bf16_kernel::cosine_bf16_avx512(x.as_ptr(), x_norm, y.as_ptr(), y.len() as u32)
117+
},
118+
#[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))]
119+
SimdSupport::Avx2 | SimdSupport::Avx512 => unsafe {
120+
bf16_kernel::cosine_bf16_avx2(x.as_ptr(), x_norm, y.as_ptr(), y.len() as u32)
121+
},
122+
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
123+
SimdSupport::Lasx => unsafe {
124+
bf16_kernel::cosine_bf16_lasx(x.as_ptr(), x_norm, y.as_ptr(), y.len() as u32)
125+
},
126+
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
127+
SimdSupport::Lsx => unsafe {
128+
bf16_kernel::cosine_bf16_lsx(x.as_ptr(), x_norm, y.as_ptr(), y.len() as u32)
129+
},
130+
_ => cosine_scalar(x, x_norm, y),
131+
}
132+
}
133+
}
76134

77135
#[cfg(feature = "fp16kernels")]
78136
mod kernel {

rust/lance-linalg/src/distance/dot.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,56 @@ pub trait Dot: Num {
7575
fn dot(x: &[Self], y: &[Self]) -> f32;
7676
}
7777

78+
#[cfg(feature = "fp16kernels")]
79+
mod bf16_kernel {
80+
use half::bf16;
81+
82+
// These are the `dot_bf16` function in bf16.c. Our build.rs script compiles
83+
// a version of this file for each SIMD level with different suffixes.
84+
unsafe extern "C" {
85+
#[cfg(target_arch = "aarch64")]
86+
pub fn dot_bf16_neon(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
87+
#[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))]
88+
pub fn dot_bf16_avx512(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
89+
#[cfg(target_arch = "x86_64")]
90+
pub fn dot_bf16_avx2(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
91+
#[cfg(target_arch = "loongarch64")]
92+
pub fn dot_bf16_lsx(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
93+
#[cfg(target_arch = "loongarch64")]
94+
pub fn dot_bf16_lasx(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
95+
}
96+
}
97+
7898
impl Dot for bf16 {
7999
#[inline]
80100
fn dot(x: &[Self], y: &[Self]) -> f32 {
81-
dot_scalar::<Self, f32, 32>(x, y)
101+
match *SIMD_SUPPORT {
102+
#[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))]
103+
SimdSupport::Neon => unsafe {
104+
bf16_kernel::dot_bf16_neon(x.as_ptr(), y.as_ptr(), x.len() as u32)
105+
},
106+
#[cfg(all(
107+
feature = "fp16kernels",
108+
kernel_support = "avx512",
109+
target_arch = "x86_64"
110+
))]
111+
SimdSupport::Avx512FP16 => unsafe {
112+
bf16_kernel::dot_bf16_avx512(x.as_ptr(), y.as_ptr(), x.len() as u32)
113+
},
114+
#[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))]
115+
SimdSupport::Avx2 | SimdSupport::Avx512 => unsafe {
116+
bf16_kernel::dot_bf16_avx2(x.as_ptr(), y.as_ptr(), x.len() as u32)
117+
},
118+
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
119+
SimdSupport::Lasx => unsafe {
120+
bf16_kernel::dot_bf16_lasx(x.as_ptr(), y.as_ptr(), x.len() as u32)
121+
},
122+
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
123+
SimdSupport::Lsx => unsafe {
124+
bf16_kernel::dot_bf16_lsx(x.as_ptr(), y.as_ptr(), x.len() as u32)
125+
},
126+
_ => dot_scalar::<Self, f32, 32>(x, y),
127+
}
82128
}
83129
}
84130

rust/lance-linalg/src/distance/l2.rs

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,56 @@ impl L2 for u8 {
9797
}
9898
}
9999

100+
#[cfg(feature = "fp16kernels")]
101+
mod bf16_kernel {
102+
use half::bf16;
103+
104+
// These are the `l2_bf16` function in bf16.c. Our build.rs script compiles
105+
// a version of this file for each SIMD level with different suffixes.
106+
unsafe extern "C" {
107+
#[cfg(target_arch = "aarch64")]
108+
pub fn l2_bf16_neon(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
109+
#[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))]
110+
pub fn l2_bf16_avx512(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
111+
#[cfg(target_arch = "x86_64")]
112+
pub fn l2_bf16_avx2(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
113+
#[cfg(target_arch = "loongarch64")]
114+
pub fn l2_bf16_lsx(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
115+
#[cfg(target_arch = "loongarch64")]
116+
pub fn l2_bf16_lasx(ptr1: *const bf16, ptr2: *const bf16, len: u32) -> f32;
117+
}
118+
}
119+
100120
impl L2 for bf16 {
101121
#[inline]
102122
fn l2(x: &[Self], y: &[Self]) -> f32 {
103-
// TODO: add SIMD support
104-
l2_scalar::<Self, f32, 16>(x, y)
123+
match *SIMD_SUPPORT {
124+
#[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))]
125+
SimdSupport::Neon => unsafe {
126+
bf16_kernel::l2_bf16_neon(x.as_ptr(), y.as_ptr(), x.len() as u32)
127+
},
128+
#[cfg(all(
129+
feature = "fp16kernels",
130+
kernel_support = "avx512",
131+
target_arch = "x86_64"
132+
))]
133+
SimdSupport::Avx512FP16 => unsafe {
134+
bf16_kernel::l2_bf16_avx512(x.as_ptr(), y.as_ptr(), x.len() as u32)
135+
},
136+
#[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))]
137+
SimdSupport::Avx2 | SimdSupport::Avx512 => unsafe {
138+
bf16_kernel::l2_bf16_avx2(x.as_ptr(), y.as_ptr(), x.len() as u32)
139+
},
140+
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
141+
SimdSupport::Lasx => unsafe {
142+
bf16_kernel::l2_bf16_lasx(x.as_ptr(), y.as_ptr(), x.len() as u32)
143+
},
144+
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
145+
SimdSupport::Lsx => unsafe {
146+
bf16_kernel::l2_bf16_lsx(x.as_ptr(), y.as_ptr(), x.len() as u32)
147+
},
148+
_ => l2_scalar::<Self, f32, 16>(x, y),
149+
}
105150
}
106151
}
107152

rust/lance-linalg/src/distance/norm_l2.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,54 @@ impl Normalize for f16 {
8080
}
8181
}
8282

83+
#[cfg(feature = "fp16kernels")]
84+
mod bf16_kernel {
85+
use half::bf16;
86+
87+
unsafe extern "C" {
88+
#[cfg(target_arch = "aarch64")]
89+
pub fn norm_l2_bf16_neon(ptr: *const bf16, len: u32) -> f32;
90+
#[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))]
91+
pub fn norm_l2_bf16_avx512(ptr: *const bf16, len: u32) -> f32;
92+
#[cfg(target_arch = "x86_64")]
93+
pub fn norm_l2_bf16_avx2(ptr: *const bf16, len: u32) -> f32;
94+
#[cfg(target_arch = "loongarch64")]
95+
pub fn norm_l2_bf16_lsx(ptr: *const bf16, len: u32) -> f32;
96+
#[cfg(target_arch = "loongarch64")]
97+
pub fn norm_l2_bf16_lasx(ptr: *const bf16, len: u32) -> f32;
98+
}
99+
}
100+
83101
impl Normalize for bf16 {
84102
#[inline]
85103
fn norm_l2(vector: &[Self]) -> f32 {
86-
norm_l2_impl::<Self, f32, 32>(vector)
104+
match *SIMD_SUPPORT {
105+
#[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))]
106+
SimdSupport::Neon => unsafe {
107+
bf16_kernel::norm_l2_bf16_neon(vector.as_ptr(), vector.len() as u32)
108+
},
109+
#[cfg(all(
110+
feature = "fp16kernels",
111+
kernel_support = "avx512",
112+
target_arch = "x86_64"
113+
))]
114+
SimdSupport::Avx512FP16 => unsafe {
115+
bf16_kernel::norm_l2_bf16_avx512(vector.as_ptr(), vector.len() as u32)
116+
},
117+
#[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))]
118+
SimdSupport::Avx2 | SimdSupport::Avx512 => unsafe {
119+
bf16_kernel::norm_l2_bf16_avx2(vector.as_ptr(), vector.len() as u32)
120+
},
121+
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
122+
SimdSupport::Lasx => unsafe {
123+
bf16_kernel::norm_l2_bf16_lasx(vector.as_ptr(), vector.len() as u32)
124+
},
125+
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
126+
SimdSupport::Lsx => unsafe {
127+
bf16_kernel::norm_l2_bf16_lsx(vector.as_ptr(), vector.len() as u32)
128+
},
129+
_ => norm_l2_impl::<Self, f32, 32>(vector),
130+
}
87131
}
88132
}
89133

0 commit comments

Comments
 (0)