|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright The Lance Authors |
3 | 3 |
|
| 4 | +#[cfg(target_arch = "aarch64")] |
| 5 | +use std::arch::aarch64::*; |
4 | 6 | #[cfg(target_arch = "x86_64")] |
5 | 7 | use std::arch::x86_64::*; |
6 | 8 |
|
@@ -57,13 +59,28 @@ pub fn sum_4bit_dist_table( |
57 | 59 | ) |
58 | 60 | } |
59 | 61 | }, |
| 62 | + #[cfg(target_arch = "aarch64")] |
| 63 | + SimdSupport::Neon => unsafe { |
| 64 | + for i in (0..n).step_by(BATCH_SIZE) { |
| 65 | + sum_dist_table_32bytes_batch_neon( |
| 66 | + &codes[i * code_len..(i + BATCH_SIZE) * code_len], |
| 67 | + dist_table, |
| 68 | + &mut dists[i..i + BATCH_SIZE], |
| 69 | + ) |
| 70 | + } |
| 71 | + }, |
60 | 72 | _ => sum_4bit_dist_table_scalar(code_len, codes, dist_table, dists), |
61 | 73 | } |
62 | 74 | } |
63 | 75 |
|
64 | 76 | #[inline] |
65 | 77 | #[allow(unused)] |
66 | | -fn sum_4bit_dist_table_scalar(code_len: usize, codes: &[u8], dist_table: &[u8], dists: &mut [u16]) { |
| 78 | +pub fn sum_4bit_dist_table_scalar( |
| 79 | + code_len: usize, |
| 80 | + codes: &[u8], |
| 81 | + dist_table: &[u8], |
| 82 | + dists: &mut [u16], |
| 83 | +) { |
67 | 84 | for (vec_block_idx, blocks) in codes.chunks_exact(BATCH_SIZE * code_len).enumerate() { |
68 | 85 | for (sub_vec_idx, block) in blocks.chunks_exact(BATCH_SIZE).enumerate() { |
69 | 86 | let current_dist_table = &dist_table[sub_vec_idx * 2 * 16..(sub_vec_idx * 2 + 1) * 16]; |
@@ -159,6 +176,77 @@ unsafe fn sum_dist_table_32bytes_batch_avx2(codes: &[u8], dist_table: &[u8], dis |
159 | 176 | _mm256_storeu_si256(dists.as_mut_ptr().add(16) as *mut __m256i, dis1); |
160 | 177 | } |
161 | 178 |
|
/// Accumulates 4-bit lookup-table distances for one batch of 32 vectors with NEON.
///
/// `codes` is consumed in 32-byte blocks, and block `k` is paired with the 32 bytes
/// of `dist_table` at the same offset. Each block is handled as two independent
/// 16-byte halves: the half's 16 table bytes act as a 16-entry lookup table, and
/// every code byte of that half selects two entries from it — one with its low
/// nibble and one with its high nibble (`vqtbl1q_u8`).
///
/// Output layout (the two halves of a block are summed into the same slot):
/// - `dists[0..8]`   — low-nibble lookups, low byte of each 16-bit lane
///   (even byte positions on little-endian targets)
/// - `dists[8..16]`  — low-nibble lookups, high byte of each 16-bit lane (odd positions)
/// - `dists[16..24]` — high-nibble lookups, low byte of each 16-bit lane
/// - `dists[24..32]` — high-nibble lookups, high byte of each 16-bit lane
///
/// All sums use wrapping 16-bit addition. `dists[..32]` is overwritten, not
/// accumulated into.
///
/// # Panics
///
/// Panics if `dists` holds fewer than 32 elements. In debug builds it also panics
/// when `codes.len()` is not a multiple of 32 or `dist_table` is shorter than
/// `codes`; in release builds any unpaired trailing bytes are ignored.
///
/// # Safety
///
/// The function executes NEON instructions, so it must only be called on a CPU
/// that supports them (the dispatcher calls it from its `SimdSupport::Neon` arm).
/// All memory accesses are kept in bounds by construction.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn sum_dist_table_32bytes_batch_neon(codes: &[u8], dist_table: &[u8], dists: &mut [u16]) {
    // Bytes of codes / table consumed per iteration, and the size of one NEON register.
    const BLOCK: usize = 32;
    const HALF: usize = 16;
    // Number of u16 results written (one per vector in the batch).
    const OUT_LEN: usize = 32;

    debug_assert!(
        codes.len() % BLOCK == 0,
        "codes length must be a multiple of 32 bytes"
    );
    debug_assert!(
        dist_table.len() >= codes.len(),
        "dist_table must cover every code block"
    );
    // One checked reslice up front makes the four raw stores below in bounds.
    let out = &mut dists[..OUT_LEN];

    let low_mask = vdupq_n_u8(0x0f);

    // 8 accumulators: 4 per 128-bit "lane" (lo = bytes 0..16, hi = bytes 16..32 of each block).
    // accu0/accu2 add raw byte pairs as u16 (so the high byte bleeds into the sum);
    // accu1/accu3 add only the high byte of each pair.
    let mut accu0_lo = vdupq_n_u16(0);
    let mut accu1_lo = vdupq_n_u16(0);
    let mut accu2_lo = vdupq_n_u16(0);
    let mut accu3_lo = vdupq_n_u16(0);
    let mut accu0_hi = vdupq_n_u16(0);
    let mut accu1_hi = vdupq_n_u16(0);
    let mut accu2_hi = vdupq_n_u16(0);
    let mut accu3_hi = vdupq_n_u16(0);

    // `chunks_exact` guarantees every block is exactly 32 bytes, so the 16-byte
    // loads below can never run past the end of either slice.
    for (code_block, lut_block) in codes.chunks_exact(BLOCK).zip(dist_table.chunks_exact(BLOCK)) {
        let (code_lo, code_hi) = code_block.split_at(HALF);
        let (table_lo, table_hi) = lut_block.split_at(HALF);

        // Process lo lane: bytes [0..16] of the block
        let c_lo = vld1q_u8(code_lo.as_ptr());
        let lut_lo = vld1q_u8(table_lo.as_ptr());

        let lo_lo = vandq_u8(c_lo, low_mask);
        // A per-byte shift already clears the upper bits, so no mask is needed.
        let hi_lo = vshrq_n_u8::<4>(c_lo);

        let res_lo_lo = vqtbl1q_u8(lut_lo, lo_lo);
        let res_hi_lo = vqtbl1q_u8(lut_lo, hi_lo);

        accu0_lo = vaddq_u16(accu0_lo, vreinterpretq_u16_u8(res_lo_lo));
        accu1_lo = vaddq_u16(accu1_lo, vshrq_n_u16::<8>(vreinterpretq_u16_u8(res_lo_lo)));
        accu2_lo = vaddq_u16(accu2_lo, vreinterpretq_u16_u8(res_hi_lo));
        accu3_lo = vaddq_u16(accu3_lo, vshrq_n_u16::<8>(vreinterpretq_u16_u8(res_hi_lo)));

        // Process hi lane: bytes [16..32] of the block
        let c_hi = vld1q_u8(code_hi.as_ptr());
        let lut_hi = vld1q_u8(table_hi.as_ptr());

        let lo_hi = vandq_u8(c_hi, low_mask);
        let hi_hi = vshrq_n_u8::<4>(c_hi);

        let res_lo_hi = vqtbl1q_u8(lut_hi, lo_hi);
        let res_hi_hi = vqtbl1q_u8(lut_hi, hi_hi);

        accu0_hi = vaddq_u16(accu0_hi, vreinterpretq_u16_u8(res_lo_hi));
        accu1_hi = vaddq_u16(accu1_hi, vshrq_n_u16::<8>(vreinterpretq_u16_u8(res_lo_hi)));
        accu2_hi = vaddq_u16(accu2_hi, vreinterpretq_u16_u8(res_hi_hi));
        accu3_hi = vaddq_u16(accu3_hi, vshrq_n_u16::<8>(vreinterpretq_u16_u8(res_hi_hi)));
    }

    // Merge: clean the low-byte sums by subtracting the high-byte bleed
    // (everything is modulo 2^16, so the subtraction undoes the carry exactly).
    accu0_lo = vsubq_u16(accu0_lo, vshlq_n_u16::<8>(accu1_lo));
    accu0_hi = vsubq_u16(accu0_hi, vshlq_n_u16::<8>(accu1_hi));

    // Cross-lane merge: add lo and hi lane accumulators.
    // This is the NEON equivalent of AVX2's permute2f128 + blend + add.
    let dis0_even = vaddq_u16(accu0_lo, accu0_hi);
    let dis0_odd = vaddq_u16(accu1_lo, accu1_hi);

    // Same for the high-nibble accumulators (outputs 16..32).
    accu2_lo = vsubq_u16(accu2_lo, vshlq_n_u16::<8>(accu3_lo));
    accu2_hi = vsubq_u16(accu2_hi, vshlq_n_u16::<8>(accu3_hi));

    let dis1_even = vaddq_u16(accu2_lo, accu2_hi);
    let dis1_odd = vaddq_u16(accu3_lo, accu3_hi);

    // `out` is exactly 32 elements long: four stores of 8 lanes each fill it.
    let out_ptr = out.as_mut_ptr();
    vst1q_u16(out_ptr, dis0_even);
    vst1q_u16(out_ptr.add(8), dis0_odd);
    vst1q_u16(out_ptr.add(16), dis1_even);
    vst1q_u16(out_ptr.add(24), dis1_odd);
}
| 249 | + |
162 | 250 | // We implement the AVX512 version in C because AVX512 is not stable yet in Rust, |
163 | 251 | // implement it in Rust once we upgrade rust to 1.89.0. |
164 | 252 | unsafe extern "C" { |
@@ -214,4 +302,83 @@ mod tests { |
214 | 302 | // so the distance is 2 * (dist_table[0x6] + dist_table[0xb + 16]) = 2*(7 + 12) = 38 |
215 | 303 | assert_eq!(dists[1], 38); |
216 | 304 | } |
| 305 | + |
/// Cross-checks `sum_4bit_dist_table` (whichever kernel it dispatches to)
/// against `sum_4bit_dist_table_scalar` for a single batch, sweeping code
/// lengths from tiny up to 8192 bytes (reported as DIM = 65536).
///
/// The SIMD kernels accumulate with wrapping 16-bit adds, while the scalar
/// reference is understood to saturate, so the two can only be compared when
/// no sum overflows `u16`. Table entries are therefore capped per code length
/// to keep every total within range.
#[test]
fn test_simd_matches_scalar_varied_dimensions() {
    use rand::{Rng, SeedableRng};

    // code_len is what the kernels see; with 1-bit quantization it equals dim / 8:
    // 16 → DIM=128, 192 → DIM=1536, 512 → DIM=4096, 8192 → DIM=65536.
    const CODE_LENS: [usize; 7] = [2, 16, 96, 192, 512, 1024, 8192];

    // Exactly one batch of vectors.
    let n = BATCH_SIZE;
    let mut rng = rand::rngs::StdRng::seed_from_u64(42);

    for code_len in CODE_LENS {
        // Every code byte triggers two table lookups, so a vector sums
        // 2 * code_len entries; bound each entry so that total fits in u16.
        let lookups_per_vector = 2 * code_len;
        let max_val = (u16::MAX as usize / lookups_per_vector).min(255) as u8;

        // Draw order matters for reproducibility: codes first, then the table.
        let codes: Vec<u8> = std::iter::repeat_with(|| rng.random::<u8>())
            .take(n * code_len)
            .collect();
        let dist_table: Vec<u8> = std::iter::repeat_with(|| rng.random_range(0..=max_val))
            .take(BATCH_SIZE * code_len)
            .collect();

        let expected = {
            let mut reference = vec![0u16; n];
            sum_4bit_dist_table_scalar(code_len, &codes, &dist_table, &mut reference);
            reference
        };
        let actual = {
            let mut dispatched = vec![0u16; n];
            sum_4bit_dist_table(n, code_len, &codes, &dist_table, &mut dispatched);
            dispatched
        };

        let dim = code_len * 8;
        assert_eq!(
            actual, expected,
            "SIMD and scalar mismatch for code_len={} (DIM={})",
            code_len, dim,
        );
    }
}
| 351 | + |
/// Compares the dispatched kernel with the scalar reference over ten
/// consecutive batches (320 vectors) that share a single distance table,
/// checking that every batch lands in its own slice of the output.
#[test]
fn test_simd_matches_scalar_multi_batch() {
    use rand::{Rng, SeedableRng};

    const NUM_BATCHES: usize = 10;

    let n = BATCH_SIZE * NUM_BATCHES;
    let mut rng = rand::rngs::StdRng::seed_from_u64(123);

    for code_len in [16, 192, 1024] {
        // Cap table entries so the 2 * code_len lookups of one vector fit in u16.
        let max_val = (u16::MAX as usize / (2 * code_len)).min(255) as u8;

        // Draw order matters for reproducibility: codes first, then the table.
        let codes: Vec<u8> = std::iter::repeat_with(|| rng.random::<u8>())
            .take(n * code_len)
            .collect();
        let dist_table: Vec<u8> = std::iter::repeat_with(|| rng.random_range(0..=max_val))
            .take(BATCH_SIZE * code_len)
            .collect();

        let expected = {
            let mut reference = vec![0u16; n];
            sum_4bit_dist_table_scalar(code_len, &codes, &dist_table, &mut reference);
            reference
        };
        let actual = {
            let mut dispatched = vec![0u16; n];
            sum_4bit_dist_table(n, code_len, &codes, &dist_table, &mut dispatched);
            dispatched
        };

        let dim = code_len * 8;
        assert_eq!(
            actual, expected,
            "SIMD and scalar mismatch for multi-batch code_len={} (DIM={}, n={})",
            code_len, dim, n,
        );
    }
}
| 383 | + } |
217 | 384 | } |
0 commit comments