Skip to content

Commit 0eaaa86

Browse files
committed
refactor(blasgraph): route Hamming/popcount through ndarray::hpc::bitwise
Under `ndarray-hpc`, `dispatch_hamming`/`dispatch_popcount` and the typed `hamming_distance_dispatch` now call `ndarray::hpc::bitwise::{hamming_distance_raw, popcount_raw}` (the canonical VPOPCNTDQ → AVX-512BW → AVX2 → scalar dispatch), per the "all SIMD from ndarray" doctrine.

In `ndarray_bridge.rs` the hand-rolled in-crate intrinsics survive only as the `#[cfg(not(feature = "ndarray-hpc"))]` fallback for minimal / non-x86 builds (CI, wasm, embedded). In `types.rs` the AVX2 / AVX-512 intrinsics are removed, and the `#[cfg(not(feature = "ndarray-hpc"))]` fallback there is the scalar path only. Mirrors the episodic.rs pattern.

Validated: `cargo test -p lance-graph --lib blasgraph` → 194 passed, 0 failed (protoc installed to unblock the lance-encoding build script).

https://claude.ai/code/session_01D2WSmezQBNC3bUdHuGfGmo
1 parent cfb530b commit 0eaaa86

2 files changed

Lines changed: 83 additions & 111 deletions

File tree

crates/lance-graph/src/graph/blasgraph/ndarray_bridge.rs

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -128,60 +128,80 @@ impl From<&ndarray::hpc::fingerprint::Fingerprint<256>> for BitVec {
128128
}
129129

130130
// ---------------------------------------------------------------------------
131-
// SIMD dispatch — 4-tier fallback matching ndarray's bitwise.rs pattern
131+
// SIMD dispatch — routed through ndarray under `ndarray-hpc`, else 4-tier
132+
// in-crate fallback. Per the "all SIMD from ndarray" doctrine the canonical
133+
// SIMD dispatch lives in `ndarray::hpc::bitwise`; the hand-rolled intrinsics
134+
// below survive only as the `#[cfg(not(feature = "ndarray-hpc"))]` fallback
135+
// for minimal / non-x86 builds (CI, wasm, embedded).
132136
// ---------------------------------------------------------------------------
133137

134138
/// SIMD-dispatched Hamming distance between two byte slices.
135139
///
136-
/// Computes `popcount(a XOR b)` using the best available instruction set:
137-
///
138-
/// 1. **VPOPCNTDQ** (AVX-512 VPOPCNTDQ) — 512-bit popcount in one instruction
139-
/// 2. **AVX-512BW** — 512-bit XOR + byte-level popcount via shuffle LUT
140-
/// 3. **AVX2** — 256-bit XOR + byte-level popcount via shuffle LUT
141-
/// 4. **Scalar** — word-by-word `count_ones()`
140+
/// Computes `popcount(a XOR b)`. Under `ndarray-hpc` this routes through
141+
/// `ndarray::hpc::bitwise::hamming_distance_raw` (VPOPCNTDQ → AVX-512BW →
142+
/// AVX2 → scalar). Without the feature it uses the in-crate 4-tier fallback
143+
/// (VPOPCNTDQ → AVX-512BW → AVX2 → scalar).
142144
///
143145
/// Both slices must have the same length. Panics otherwise.
144146
pub fn dispatch_hamming(a: &[u8], b: &[u8]) -> u64 {
145147
assert_eq!(a.len(), b.len(), "hamming: slices must have equal length");
146148

147-
#[cfg(target_arch = "x86_64")]
149+
#[cfg(feature = "ndarray-hpc")]
148150
{
149-
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512f") {
150-
// SAFETY: feature detection guarantees VPOPCNTDQ is available.
151-
return unsafe { hamming_avx512_vpopcntdq(a, b) };
152-
}
153-
if is_x86_feature_detected!("avx512bw") && is_x86_feature_detected!("avx512f") {
154-
// SAFETY: feature detection guarantees AVX-512BW is available.
155-
return unsafe { hamming_avx512bw(a, b) };
156-
}
157-
if is_x86_feature_detected!("avx2") {
158-
// SAFETY: feature detection guarantees AVX2 is available.
159-
return unsafe { hamming_avx2(a, b) };
160-
}
151+
// Lengths are equal (asserted above), so ndarray's `min(len)` is exact.
152+
ndarray::hpc::bitwise::hamming_distance_raw(a, b)
161153
}
162154

163-
hamming_scalar(a, b)
155+
#[cfg(not(feature = "ndarray-hpc"))]
156+
{
157+
#[cfg(target_arch = "x86_64")]
158+
{
159+
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512f") {
160+
// SAFETY: feature detection guarantees VPOPCNTDQ is available.
161+
return unsafe { hamming_avx512_vpopcntdq(a, b) };
162+
}
163+
if is_x86_feature_detected!("avx512bw") && is_x86_feature_detected!("avx512f") {
164+
// SAFETY: feature detection guarantees AVX-512BW is available.
165+
return unsafe { hamming_avx512bw(a, b) };
166+
}
167+
if is_x86_feature_detected!("avx2") {
168+
// SAFETY: feature detection guarantees AVX2 is available.
169+
return unsafe { hamming_avx2(a, b) };
170+
}
171+
}
172+
173+
hamming_scalar(a, b)
174+
}
164175
}
165176

166177
/// SIMD-dispatched population count over a byte slice.
167178
///
168-
/// Uses the same 4-tier fallback as `dispatch_hamming`:
169-
/// VPOPCNTDQ -> AVX-512BW -> AVX2 -> scalar.
179+
/// Under `ndarray-hpc` this routes through `ndarray::hpc::bitwise::
180+
/// popcount_raw`. Without the feature it uses the same in-crate 4-tier
181+
/// fallback as `dispatch_hamming` (VPOPCNTDQ → AVX-512BW → AVX2 → scalar).
170182
pub fn dispatch_popcount(a: &[u8]) -> u64 {
171-
#[cfg(target_arch = "x86_64")]
183+
#[cfg(feature = "ndarray-hpc")]
172184
{
173-
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512f") {
174-
return unsafe { popcount_avx512_vpopcntdq(a) };
175-
}
176-
if is_x86_feature_detected!("avx512bw") && is_x86_feature_detected!("avx512f") {
177-
return unsafe { popcount_avx512bw(a) };
178-
}
179-
if is_x86_feature_detected!("avx2") {
180-
return unsafe { popcount_avx2(a) };
181-
}
185+
ndarray::hpc::bitwise::popcount_raw(a)
182186
}
183187

184-
popcount_scalar(a)
188+
#[cfg(not(feature = "ndarray-hpc"))]
189+
{
190+
#[cfg(target_arch = "x86_64")]
191+
{
192+
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512f") {
193+
return unsafe { popcount_avx512_vpopcntdq(a) };
194+
}
195+
if is_x86_feature_detected!("avx512bw") && is_x86_feature_detected!("avx512f") {
196+
return unsafe { popcount_avx512bw(a) };
197+
}
198+
if is_x86_feature_detected!("avx2") {
199+
return unsafe { popcount_avx2(a) };
200+
}
201+
}
202+
203+
popcount_scalar(a)
204+
}
185205
}
186206

187207
// ---------------------------------------------------------------------------

crates/lance-graph/src/graph/blasgraph/types.rs

Lines changed: 29 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,16 @@ pub enum SelectOp {
420420

421421
// ─── SIMD-dispatched Hamming distance ─────────────────────────────────
422422
//
423-
// Dispatch chain: AVX-512 VPOPCNTDQ → AVX2 → scalar.
424-
// Uses `std::arch` intrinsics only, no external crate.
423+
// Per the "all SIMD from ndarray" doctrine, the SIMD dispatch lives in
424+
// `ndarray::hpc::bitwise` (VPOPCNTDQ → AVX-512BW → AVX2 → scalar) and is
425+
// routed in under the `ndarray-hpc` feature. The hand-rolled scalar path
426+
// below is the `#[cfg(not(feature = "ndarray-hpc"))]` fallback so minimal /
427+
// non-x86 builds (CI, wasm, embedded) keep working without the dep.
425428

426429
/// Scalar fallback: portable popcount via `count_ones()`.
430+
///
431+
/// Used as the `#[cfg(not(feature = "ndarray-hpc"))]` Hamming path and by the
432+
/// in-crate parity tests.
427433
fn hamming_distance_scalar(a: &[u64; VECTOR_WORDS], b: &[u64; VECTOR_WORDS]) -> u32 {
428434
let mut dist = 0u32;
429435
for i in 0..VECTOR_WORDS {
@@ -432,85 +438,31 @@ fn hamming_distance_scalar(a: &[u64; VECTOR_WORDS], b: &[u64; VECTOR_WORDS]) ->
432438
dist
433439
}
434440

435-
/// AVX2 implementation: processes 4 × u64 = 256 bits per iteration.
436-
/// Uses the Harley-Seal popcount algorithm on 256-bit XOR results.
437-
#[cfg(target_arch = "x86_64")]
438-
#[target_feature(enable = "avx2")]
439-
unsafe fn hamming_distance_avx2(a: &[u64; VECTOR_WORDS], b: &[u64; VECTOR_WORDS]) -> u32 {
440-
use std::arch::x86_64::*;
441-
442-
// Lookup table for 4-bit popcount
443-
let lookup = _mm256_setr_epi8(
444-
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3,
445-
3, 4,
446-
);
447-
let low_mask = _mm256_set1_epi8(0x0f);
448-
let mut total = _mm256_setzero_si256();
449-
450-
let a_ptr = a.as_ptr() as *const __m256i;
451-
let b_ptr = b.as_ptr() as *const __m256i;
452-
let n_vecs = VECTOR_WORDS / 4; // 256 / 4 = 64 iterations
453-
454-
for i in 0..n_vecs {
455-
let va = _mm256_loadu_si256(a_ptr.add(i));
456-
let vb = _mm256_loadu_si256(b_ptr.add(i));
457-
let xor = _mm256_xor_si256(va, vb);
458-
459-
// Popcount via lookup table (Mula et al.)
460-
let lo = _mm256_and_si256(xor, low_mask);
461-
let hi = _mm256_and_si256(_mm256_srli_epi16(xor, 4), low_mask);
462-
let popcnt_lo = _mm256_shuffle_epi8(lookup, lo);
463-
let popcnt_hi = _mm256_shuffle_epi8(lookup, hi);
464-
let popcnt = _mm256_add_epi8(popcnt_lo, popcnt_hi);
465-
466-
// Horizontal sum within bytes → u64 sums via sad
467-
let sad = _mm256_sad_epu8(popcnt, _mm256_setzero_si256());
468-
total = _mm256_add_epi64(total, sad);
469-
}
470-
471-
// Extract and sum the 4 u64 lanes
472-
let lo128 = _mm256_castsi256_si128(total);
473-
let hi128 = _mm256_extracti128_si256(total, 1);
474-
let sum128 = _mm_add_epi64(lo128, hi128);
475-
let upper = _mm_unpackhi_epi64(sum128, sum128);
476-
let final_sum = _mm_add_epi64(sum128, upper);
477-
_mm_cvtsi128_si64(final_sum) as u32
478-
}
479-
480-
/// AVX-512 VPOPCNTDQ implementation: processes 8 × u64 = 512 bits per iteration.
481-
#[cfg(target_arch = "x86_64")]
482-
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
483-
unsafe fn hamming_distance_avx512(a: &[u64; VECTOR_WORDS], b: &[u64; VECTOR_WORDS]) -> u32 {
484-
use std::arch::x86_64::*;
485-
486-
let mut total = _mm512_setzero_si512();
487-
let a_ptr = a.as_ptr() as *const __m512i;
488-
let b_ptr = b.as_ptr() as *const __m512i;
489-
let n_vecs = VECTOR_WORDS / 8; // 256 / 8 = 32 iterations
490-
491-
for i in 0..n_vecs {
492-
let va = _mm512_loadu_si512(a_ptr.add(i));
493-
let vb = _mm512_loadu_si512(b_ptr.add(i));
494-
let xor = _mm512_xor_si512(va, vb);
495-
let popcnt = _mm512_popcnt_epi64(xor);
496-
total = _mm512_add_epi64(total, popcnt);
441+
/// Runtime-dispatched Hamming distance.
442+
///
443+
/// Under `ndarray-hpc` this routes through `ndarray::hpc::bitwise::
444+
/// hamming_distance_raw` (the canonical SIMD dispatch shared with the rest of
445+
/// the Ada stack). Without the feature it falls back to the in-crate scalar
446+
/// path. Both views reinterpret the same `[u64; VECTOR_WORDS]` backing store
447+
/// as native-endian bytes; Hamming distance is a bit count and is therefore
448+
/// invariant under the (consistent) byte layout on both operands.
449+
fn hamming_distance_dispatch(a: &[u64; VECTOR_WORDS], b: &[u64; VECTOR_WORDS]) -> u32 {
450+
#[cfg(feature = "ndarray-hpc")]
451+
{
452+
const BYTE_LEN: usize = VECTOR_WORDS * 8;
453+
// SAFETY: `[u64; VECTOR_WORDS]` is plain-old-data with no padding; a
454+
// `&[u8]` view of the same `BYTE_LEN` bytes is always valid (u8 has
455+
// alignment 1). Same layout on both operands → bit count is exact.
456+
let a_bytes = unsafe { std::slice::from_raw_parts(a.as_ptr() as *const u8, BYTE_LEN) };
457+
let b_bytes = unsafe { std::slice::from_raw_parts(b.as_ptr() as *const u8, BYTE_LEN) };
458+
// Max distance is VECTOR_BITS (16384), well within u32.
459+
ndarray::hpc::bitwise::hamming_distance_raw(a_bytes, b_bytes) as u32
497460
}
498461

499-
_mm512_reduce_add_epi64(total) as u32
500-
}
501-
502-
/// Runtime-dispatched Hamming distance using best available SIMD.
503-
fn hamming_distance_dispatch(a: &[u64; VECTOR_WORDS], b: &[u64; VECTOR_WORDS]) -> u32 {
504-
#[cfg(target_arch = "x86_64")]
462+
#[cfg(not(feature = "ndarray-hpc"))]
505463
{
506-
if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vpopcntdq") {
507-
return unsafe { hamming_distance_avx512(a, b) };
508-
}
509-
if is_x86_feature_detected!("avx2") {
510-
return unsafe { hamming_distance_avx2(a, b) };
511-
}
464+
hamming_distance_scalar(a, b)
512465
}
513-
hamming_distance_scalar(a, b)
514466
}
515467

516468
#[cfg(test)]

0 commit comments

Comments
 (0)