Skip to content

Commit 8ba065c

Browse files
authored
Merge pull request #76 from AdaWorldAPI/claude/setup-embedding-pipeline-Fa65C
feat: U8x64 byte-level ops for palette codec, nibble, byte scan (Pumpkin/SD) Added to all three tiers (AVX-512 / AVX2 / scalar): cmpeq_mask(other) → u64 — byte-wise equality, returns bitmask shr_epi16(imm) → Self — shift right 16-bit lanes (nibble extract) saturating_sub(other) — max(a-b, 0) per byte (delta subtraction) unpack_lo_epi8(other) — interleave low bytes (nibble interleave) unpack_hi_epi8(other) — interleave high bytes These operations are used by: palette_codec.rs — Minecraft-style variable-width bit packing nibble.rs — 4-bit light level packing (Pumpkin) byte_scan.rs — NBT format byte scanning (future) stable_diffusion/ — VAE latent palette encoding via GGUF All three are currently using raw _mm256_/_mm512_ intrinsics. Next step: rewire them to use crate::simd::U8x64 instead. https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp
2 parents 3bc3651 + bad2a55 commit 8ba065c

3 files changed

Lines changed: 492 additions & 4 deletions

File tree

src/simd.rs

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,64 @@ static TIER: LazyLock<Tier> = LazyLock::new(|| {
2222
#[inline(always)]
2323
fn tier() -> Tier { *TIER }
2424

25+
// BF16 tier detection happens inline in bf16_to_f32_batch() via
26+
// is_x86_feature_detected!("avx512bf16") — no LazyLock needed.
27+
// The check is cheap (reads a cached cpuid result) and the batch
28+
// function uses as_chunks::<16>() + as_chunks::<8>() for SIMD widths.
29+
30+
// ============================================================================
31+
// Preferred SIMD lane widths — compile-time constants for array_windows
32+
// ============================================================================
33+
//
34+
// Consumer code uses these to select array_windows size at compile time:
35+
//
36+
// for window in data.array_windows::<{crate::simd::PREFERRED_F64_LANES}>() {
37+
// let v = F64x8::from_array(*window); // AVX-512: native 8-wide
38+
// // or
39+
// let v = F64x4::from_array(*window); // AVX2: native 4-wide
40+
// }
41+
//
42+
// generic_const_exprs is nightly, so consumers must #[cfg] branch on window size.
43+
// These constants document the preferred width per tier.
44+
45+
/// Preferred f64 SIMD width (elements per register).
46+
/// AVX-512: 8 lanes (__m512d). AVX2/scalar: 4 lanes (__m256d).
47+
#[cfg(target_feature = "avx512f")]
48+
pub const PREFERRED_F64_LANES: usize = 8;
49+
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
50+
pub const PREFERRED_F64_LANES: usize = 4;
51+
#[cfg(not(target_arch = "x86_64"))]
52+
pub const PREFERRED_F64_LANES: usize = 4; // scalar fallback: same as AVX2 shape
53+
54+
/// Preferred f32 SIMD width.
55+
/// AVX-512: 16 lanes (__m512). AVX2/scalar: 8 lanes (__m256).
56+
#[cfg(target_feature = "avx512f")]
57+
pub const PREFERRED_F32_LANES: usize = 16;
58+
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
59+
pub const PREFERRED_F32_LANES: usize = 8;
60+
#[cfg(not(target_arch = "x86_64"))]
61+
pub const PREFERRED_F32_LANES: usize = 8;
62+
63+
/// Preferred u64 SIMD width.
64+
/// AVX-512: 8 lanes. AVX2/scalar: 4 lanes.
65+
#[cfg(target_feature = "avx512f")]
66+
pub const PREFERRED_U64_LANES: usize = 8;
67+
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
68+
pub const PREFERRED_U64_LANES: usize = 4;
69+
#[cfg(not(target_arch = "x86_64"))]
70+
pub const PREFERRED_U64_LANES: usize = 4;
71+
72+
/// Preferred i16 SIMD width (for Base17 L1 on i16[17]).
73+
/// AVX-512: 32 lanes (__m512i via epi16). AVX2: 16 lanes (__m256i).
74+
/// Base17 has 17 dims — AVX-512 covers 32 (load 17 + 15 padding),
75+
/// AVX2 covers 16 + 1 scalar.
76+
#[cfg(target_feature = "avx512f")]
77+
pub const PREFERRED_I16_LANES: usize = 32;
78+
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
79+
pub const PREFERRED_I16_LANES: usize = 16;
80+
#[cfg(not(target_arch = "x86_64"))]
81+
pub const PREFERRED_I16_LANES: usize = 16;
82+
2583
// ============================================================================
2684
// x86_64: re-export based on tier
2785
// ============================================================================
@@ -41,6 +99,16 @@ pub use crate::simd_avx512::{
4199
f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8,
42100
};
43101

102+
// BF16 types + batch conversion (always available — scalar fallback built in)
103+
#[cfg(target_arch = "x86_64")]
104+
pub use crate::simd_avx512::{
105+
bf16_to_f32_scalar, f32_to_bf16_scalar,
106+
bf16_to_f32_batch, f32_to_bf16_batch,
107+
};
108+
// BF16 SIMD types only available when avx512bf16 is enabled at compile time
109+
#[cfg(all(target_arch = "x86_64", target_feature = "avx512bf16"))]
110+
pub use crate::simd_avx512::{BF16x16, BF16x8};
111+
44112
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
45113
pub use crate::simd_avx512::{F32x8, F64x4, f32x8, f64x4};
46114

@@ -645,22 +713,51 @@ mod scalar {
645713
fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; }
646714
}
647715

648-
// U8x64 extra methods
716+
// U8x64 extra methods — byte-level operations for palette codec, nibble, byte scan
649717
impl U8x64 {
650718
#[inline(always)]
651719
pub fn reduce_min(self) -> u8 { *self.0.iter().min().unwrap_or(&0) }
652720
#[inline(always)]
653721
pub fn reduce_max(self) -> u8 { *self.0.iter().max().unwrap_or(&0) }
654722
#[inline(always)]
655723
pub fn simd_min(self, other: Self) -> Self {
724+
let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].min(other.0[i]); } Self(out)
725+
}
726+
#[inline(always)]
727+
pub fn simd_max(self, other: Self) -> Self {
728+
let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].max(other.0[i]); } Self(out)
729+
}
730+
#[inline(always)]
731+
pub fn cmpeq_mask(self, other: Self) -> u64 {
732+
let mut mask = 0u64;
733+
for i in 0..64 { if self.0[i] == other.0[i] { mask |= 1u64 << i; } }
734+
mask
735+
}
736+
#[inline(always)]
737+
pub fn shr_epi16(self, imm: u32) -> Self {
656738
let mut out = [0u8; 64];
657-
for i in 0..64 { out[i] = self.0[i].min(other.0[i]); }
739+
for i in (0..64).step_by(2) {
740+
let val = u16::from_le_bytes([self.0[i], self.0[i + 1]]);
741+
let shifted = val >> imm;
742+
let bytes = shifted.to_le_bytes();
743+
out[i] = bytes[0]; out[i + 1] = bytes[1];
744+
}
658745
Self(out)
659746
}
660747
#[inline(always)]
661-
pub fn simd_max(self, other: Self) -> Self {
748+
pub fn saturating_sub(self, other: Self) -> Self {
749+
let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].saturating_sub(other.0[i]); } Self(out)
750+
}
751+
#[inline(always)]
752+
pub fn unpack_lo_epi8(self, other: Self) -> Self {
662753
let mut out = [0u8; 64];
663-
for i in 0..64 { out[i] = self.0[i].max(other.0[i]); }
754+
for lane in 0..4 { let b = lane * 16; for i in 0..8 { out[b+i*2] = self.0[b+i]; out[b+i*2+1] = other.0[b+i]; } }
755+
Self(out)
756+
}
757+
#[inline(always)]
758+
pub fn unpack_hi_epi8(self, other: Self) -> Self {
759+
let mut out = [0u8; 64];
760+
for lane in 0..4 { let b = lane * 16; for i in 0..8 { out[b+i*2] = self.0[b+8+i]; out[b+i*2+1] = other.0[b+8+i]; } }
664761
Self(out)
665762
}
666763
}
@@ -697,6 +794,20 @@ pub use scalar::{
697794
f32x8, f64x4,
698795
};
699796

797+
// Scalar BF16 conversion — always available on all platforms
798+
#[cfg(not(target_arch = "x86_64"))]
799+
pub fn bf16_to_f32_scalar(bits: u16) -> f32 { f32::from_bits((bits as u32) << 16) }
800+
#[cfg(not(target_arch = "x86_64"))]
801+
pub fn f32_to_bf16_scalar(v: f32) -> u16 { (v.to_bits() >> 16) as u16 }
802+
#[cfg(not(target_arch = "x86_64"))]
803+
pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) {
804+
for (i, &b) in input.iter().enumerate() { if i < output.len() { output[i] = bf16_to_f32_scalar(b); } }
805+
}
806+
#[cfg(not(target_arch = "x86_64"))]
807+
pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) {
808+
for (i, &v) in input.iter().enumerate() { if i < output.len() { output[i] = f32_to_bf16_scalar(v); } }
809+
}
810+
700811
// ============================================================================
701812
// SIMD math functions — ndarray additions (not in std::simd)
702813
// ============================================================================

src/simd_avx2.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,76 @@ macro_rules! avx2_int_type {
761761
}
762762

763763
avx2_int_type!(U8x64, u8, 64, 0u8);
764+
765+
// ── U8x64 byte-level operations (scalar fallback for AVX2 tier) ──────────
766+
// These match the AVX-512 U8x64 methods in simd_avx512.rs.
767+
impl U8x64 {
768+
/// Byte-wise equality mask: bit i set if self[i] == other[i].
769+
#[inline(always)]
770+
pub fn cmpeq_mask(self, other: Self) -> u64 {
771+
let mut mask = 0u64;
772+
for i in 0..64 { if self.0[i] == other.0[i] { mask |= 1u64 << i; } }
773+
mask
774+
}
775+
776+
/// Shift right each 16-bit lane by imm bits (operates on pairs of u8 as u16).
777+
#[inline(always)]
778+
pub fn shr_epi16(self, imm: u32) -> Self {
779+
let mut out = [0u8; 64];
780+
for i in (0..64).step_by(2) {
781+
let val = u16::from_le_bytes([self.0[i], self.0[i + 1]]);
782+
let shifted = val >> imm;
783+
let bytes = shifted.to_le_bytes();
784+
out[i] = bytes[0];
785+
out[i + 1] = bytes[1];
786+
}
787+
Self(out)
788+
}
789+
790+
/// Saturating unsigned subtraction: max(a - b, 0) per byte.
791+
#[inline(always)]
792+
pub fn saturating_sub(self, other: Self) -> Self {
793+
let mut out = [0u8; 64];
794+
for i in 0..64 { out[i] = self.0[i].saturating_sub(other.0[i]); }
795+
Self(out)
796+
}
797+
798+
/// Interleave low bytes within each 128-bit lane.
799+
#[inline(always)]
800+
pub fn unpack_lo_epi8(self, other: Self) -> Self {
801+
let mut out = [0u8; 64];
802+
// Operates per 16-byte lane (4 lanes in 512-bit)
803+
for lane in 0..4 {
804+
let base = lane * 16;
805+
for i in 0..8 {
806+
out[base + i * 2] = self.0[base + i];
807+
out[base + i * 2 + 1] = other.0[base + i];
808+
}
809+
}
810+
Self(out)
811+
}
812+
813+
/// Interleave high bytes within each 128-bit lane.
814+
#[inline(always)]
815+
pub fn unpack_hi_epi8(self, other: Self) -> Self {
816+
let mut out = [0u8; 64];
817+
for lane in 0..4 {
818+
let base = lane * 16;
819+
for i in 0..8 {
820+
out[base + i * 2] = self.0[base + 8 + i];
821+
out[base + i * 2 + 1] = other.0[base + 8 + i];
822+
}
823+
}
824+
Self(out)
825+
}
826+
827+
/// Reduce min/max (not in macro).
828+
#[inline(always)] pub fn reduce_min(self) -> u8 { *self.0.iter().min().unwrap() }
829+
#[inline(always)] pub fn reduce_max(self) -> u8 { *self.0.iter().max().unwrap() }
830+
#[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut o = [0u8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
831+
#[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut o = [0u8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
832+
}
833+
764834
avx2_int_type!(I32x16, i32, 16, 0i32);
765835
avx2_int_type!(I64x8, i64, 8, 0i64);
766836
avx2_int_type!(U32x16, u32, 16, 0u32);

0 commit comments

Comments
 (0)