Skip to content

Commit e084042

Browse files
authored
Merge pull request #158 from AdaWorldAPI/claude/review-rustynum-pr-80-2zNy5
Integrate rustynum gestalt types into SPO architecture
2 parents f6a9e8a + 4941983 commit e084042

8 files changed

Lines changed: 1119 additions & 421 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ crewai-vendor = { package = "crewai", path = "../crewai-rust" }
211211
# -----------------------------------------------------------------------------
212212
rustynum-rs = { path = "../rustynum/rustynum-rs" }
213213
rustynum-core = { path = "../rustynum/rustynum-core", features = ["avx512"] }
214+
rustynum-bnn = { path = "../rustynum/rustynum-bnn", features = ["avx512"] }
214215
rustynum-arrow = { path = "../rustynum/rustynum-arrow", default-features = false, features = ["arrow"] }
215216
rustynum-holo = { path = "../rustynum/rustynum-holo", features = ["avx512"] }
216217
rustynum-clam = { path = "../rustynum/rustynum-clam", features = ["avx512"] }

rust-toolchain.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[toolchain]
2+
channel = "stable"

src/core/scent.rs

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -512,29 +512,8 @@ fn timestamp() -> u64 {
512512
.unwrap_or(0)
513513
}
514514

515-
// ========== SIMD Optimized Scent Scan ==========
516-
517-
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
518-
mod simd {
519-
use super::{BUCKETS, SCENT_BYTES, scent_distance};
520-
521-
/// SIMD-optimized scent scan (AVX2)
522-
/// Compares query against 256 scents, returns matching chunk IDs
523-
pub fn find_chunks_simd(
524-
scents: &[[u8; SCENT_BYTES]; BUCKETS],
525-
query: &[u8; SCENT_BYTES],
526-
threshold: u32,
527-
) -> Vec<u8> {
528-
// For now, fall back to scalar
529-
// TODO: Implement AVX2 version
530-
scents
531-
.iter()
532-
.enumerate()
533-
.filter(|(_, s)| scent_distance(s, query) <= threshold)
534-
.map(|(i, _)| i as u8)
535-
.collect()
536-
}
537-
}
515+
// NOTE: SIMD-optimized scent scan is handled via rustynum runtime dispatch.
516+
// No compile-time SIMD gates needed — rustynum detects AVX-512/AVX2 at runtime.
538517

539518
#[cfg(test)]
540519
mod tests {

src/core/simd.rs

Lines changed: 8 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
11
//! SIMD-accelerated Hamming distance computation.
22
//!
3-
//! Automatically selects the best implementation:
4-
//! - AVX-512 VPOPCNTDQ (Intel Ice Lake+, AMD Zen 4+)
5-
//! - AVX2 + manual popcount
6-
//! - NEON + CNT (ARM)
7-
//! - Scalar fallback
3+
//! All SIMD dispatch is handled by rustynum-core at runtime.
4+
//! This module is a thin wrapper — ladybug-rs NEVER reimplements SIMD.
5+
//!
6+
//! Dispatch path:
7+
//! AVX-512 VPOPCNTDQ → AVX2 Harley-Seal → scalar POPCNT
8+
//! (one binary, all CPUs, runtime CPUID detection)
89
910
use crate::FINGERPRINT_U64;
1011
use crate::core::Fingerprint;
1112

1213
/// Compute Hamming distance between two fingerprints.
1314
///
14-
/// Uses runtime-dispatched SIMD via rustynum (works on any x86_64 CPU
15-
/// without compile-time `-C target-feature`). Detects AVX-512 VPOPCNTDQ
16-
/// at runtime via `is_x86_feature_detected!()`.
15+
/// Delegates to rustynum's runtime-dispatched SIMD (AVX-512 → AVX2 → scalar).
1716
#[inline]
1817
pub fn hamming_distance(a: &Fingerprint, b: &Fingerprint) -> u32 {
1918
crate::core::rustynum_accel::fingerprint_hamming(a, b)
2019
}
2120

22-
/// Scalar implementation (works everywhere)
21+
/// Scalar reference implementation (for tests only).
2322
#[inline]
2423
pub fn hamming_scalar(a: &Fingerprint, b: &Fingerprint) -> u32 {
2524
let a_data = a.as_raw();
@@ -32,141 +31,6 @@ pub fn hamming_scalar(a: &Fingerprint, b: &Fingerprint) -> u32 {
3231
total
3332
}
3433

35-
/// AVX-512 with VPOPCNTDQ instruction (fastest)
36-
#[cfg(all(target_arch = "x86_64", target_feature = "avx512vpopcntdq"))]
37-
#[target_feature(enable = "avx512f", enable = "avx512vpopcntdq")]
38-
unsafe fn hamming_avx512(a: &Fingerprint, b: &Fingerprint) -> u32 {
39-
unsafe {
40-
use std::arch::x86_64::*;
41-
42-
let a_ptr = a.as_raw().as_ptr();
43-
let b_ptr = b.as_raw().as_ptr();
44-
45-
let mut sum = _mm512_setzero_si512();
46-
47-
// Process 8 u64 at a time (512 bits)
48-
let mut i = 0;
49-
while i + 8 <= FINGERPRINT_U64 {
50-
let va = _mm512_loadu_si512(a_ptr.add(i) as *const __m512i);
51-
let vb = _mm512_loadu_si512(b_ptr.add(i) as *const __m512i);
52-
let xor = _mm512_xor_si512(va, vb);
53-
let popcnt = _mm512_popcnt_epi64(xor);
54-
sum = _mm512_add_epi64(sum, popcnt);
55-
i += 8;
56-
}
57-
58-
// Horizontal sum
59-
let mut total = _mm512_reduce_add_epi64(sum) as u32;
60-
61-
// Handle remaining words (none at 16K bits: 256 u64 words, 256 % 8 == 0)
62-
while i < FINGERPRINT_U64 {
63-
total += (*a_ptr.add(i) ^ *b_ptr.add(i)).count_ones();
64-
i += 1;
65-
}
66-
67-
total
68-
}
69-
}
70-
71-
/// AVX2 implementation (fallback for older x86_64)
72-
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
73-
#[target_feature(enable = "avx2")]
74-
unsafe fn hamming_avx2(a: &Fingerprint, b: &Fingerprint) -> u32 {
75-
unsafe {
76-
use std::arch::x86_64::*;
77-
78-
let a_ptr = a.as_raw().as_ptr();
79-
let b_ptr = b.as_raw().as_ptr();
80-
81-
// Lookup table for 4-bit popcount
82-
let lookup = _mm256_setr_epi8(
83-
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2,
84-
3, 3, 4,
85-
);
86-
let low_mask = _mm256_set1_epi8(0x0f);
87-
88-
let mut total_sum = _mm256_setzero_si256();
89-
90-
// Process 4 u64 at a time (256 bits)
91-
let mut i = 0;
92-
while i + 4 <= FINGERPRINT_U64 {
93-
let va = _mm256_loadu_si256(a_ptr.add(i) as *const __m256i);
94-
let vb = _mm256_loadu_si256(b_ptr.add(i) as *const __m256i);
95-
let xor = _mm256_xor_si256(va, vb);
96-
97-
// Popcount via lookup table
98-
let lo = _mm256_and_si256(xor, low_mask);
99-
let hi = _mm256_and_si256(_mm256_srli_epi16(xor, 4), low_mask);
100-
let popcnt_lo = _mm256_shuffle_epi8(lookup, lo);
101-
let popcnt_hi = _mm256_shuffle_epi8(lookup, hi);
102-
let popcnt = _mm256_add_epi8(popcnt_lo, popcnt_hi);
103-
104-
// Sum bytes
105-
let sad = _mm256_sad_epu8(popcnt, _mm256_setzero_si256());
106-
total_sum = _mm256_add_epi64(total_sum, sad);
107-
108-
i += 4;
109-
}
110-
111-
// Horizontal sum
112-
let sum_lo = _mm256_extracti128_si256(total_sum, 0);
113-
let sum_hi = _mm256_extracti128_si256(total_sum, 1);
114-
let sum128 = _mm_add_epi64(sum_lo, sum_hi);
115-
let mut total = (_mm_extract_epi64(sum128, 0) + _mm_extract_epi64(sum128, 1)) as u32;
116-
117-
// Handle remaining
118-
while i < FINGERPRINT_U64 {
119-
total += (*a_ptr.add(i) ^ *b_ptr.add(i)).count_ones();
120-
i += 1;
121-
}
122-
123-
total
124-
}
125-
}
126-
127-
/// ARM NEON implementation
128-
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
129-
#[target_feature(enable = "neon")]
130-
unsafe fn hamming_neon(a: &Fingerprint, b: &Fingerprint) -> u32 {
131-
unsafe {
132-
use std::arch::aarch64::*;
133-
134-
let a_ptr = a.as_raw().as_ptr() as *const u8;
135-
let b_ptr = b.as_raw().as_ptr() as *const u8;
136-
137-
let mut sum = vdupq_n_u64(0);
138-
139-
// Process 16 bytes at a time
140-
let mut i = 0;
141-
let byte_len = FINGERPRINT_U64 * 8;
142-
while i + 16 <= byte_len {
143-
let va = vld1q_u8(a_ptr.add(i));
144-
let vb = vld1q_u8(b_ptr.add(i));
145-
let xor = veorq_u8(va, vb);
146-
let cnt = vcntq_u8(xor); // Count bits per byte
147-
148-
// Sum to 64-bit
149-
let sum16 = vpaddlq_u8(cnt); // u8 -> u16
150-
let sum32 = vpaddlq_u16(sum16); // u16 -> u32
151-
let sum64 = vpaddlq_u32(sum32); // u32 -> u64
152-
sum = vaddq_u64(sum, sum64);
153-
154-
i += 16;
155-
}
156-
157-
// Horizontal sum
158-
let mut total = (vgetq_lane_u64(sum, 0) + vgetq_lane_u64(sum, 1)) as u32;
159-
160-
// Handle remaining bytes
161-
while i < byte_len {
162-
total += (*a_ptr.add(i) ^ *b_ptr.add(i)).count_ones();
163-
i += 1;
164-
}
165-
166-
total
167-
}
168-
}
169-
17034
/// Batch Hamming distance computation (parallel)
17135
#[cfg(feature = "parallel")]
17236
pub fn batch_hamming(query: &Fingerprint, corpus: &[Fingerprint]) -> Vec<u32> {

0 commit comments

Comments
 (0)