Skip to content

Commit bad2a55

Browse files
committed
feat: U8x64 byte-level ops for palette codec, nibble, byte scan (Pumpkin/SD)
Added to all three tiers (AVX-512 / AVX2 / scalar): cmpeq_mask(other) → u64 — byte-wise equality, returns bitmask shr_epi16(imm) → Self — shift right 16-bit lanes (nibble extract) saturating_sub(other) — max(a-b, 0) per byte (delta subtraction) unpack_lo_epi8(other) — interleave low bytes (nibble interleave) unpack_hi_epi8(other) — interleave high bytes These operations are used by: palette_codec.rs — Minecraft-style variable-width bit packing nibble.rs — 4-bit light level packing (Pumpkin) byte_scan.rs — NBT format byte scanning (future) stable_diffusion/ — VAE latent palette encoding via GGUF All three are currently using raw _mm256_/_mm512_ intrinsics. Next step: rewire them to use crate::simd::U8x64 instead. https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp
1 parent 1b06969 commit bad2a55

3 files changed

Lines changed: 150 additions & 4 deletions

File tree

src/simd.rs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -713,22 +713,51 @@ mod scalar {
713713
fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; }
714714
}
715715

716-
// U8x64 extra methods
716+
// U8x64 extra methods — byte-level operations for palette codec, nibble, byte scan
717717
impl U8x64 {
718718
#[inline(always)]
719719
pub fn reduce_min(self) -> u8 { *self.0.iter().min().unwrap_or(&0) }
720720
#[inline(always)]
721721
pub fn reduce_max(self) -> u8 { *self.0.iter().max().unwrap_or(&0) }
722722
#[inline(always)]
723723
pub fn simd_min(self, other: Self) -> Self {
724+
let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].min(other.0[i]); } Self(out)
725+
}
726+
#[inline(always)]
727+
pub fn simd_max(self, other: Self) -> Self {
728+
let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].max(other.0[i]); } Self(out)
729+
}
730+
#[inline(always)]
731+
pub fn cmpeq_mask(self, other: Self) -> u64 {
732+
let mut mask = 0u64;
733+
for i in 0..64 { if self.0[i] == other.0[i] { mask |= 1u64 << i; } }
734+
mask
735+
}
736+
#[inline(always)]
737+
pub fn shr_epi16(self, imm: u32) -> Self {
724738
let mut out = [0u8; 64];
725-
for i in 0..64 { out[i] = self.0[i].min(other.0[i]); }
739+
for i in (0..64).step_by(2) {
740+
let val = u16::from_le_bytes([self.0[i], self.0[i + 1]]);
741+
let shifted = val >> imm;
742+
let bytes = shifted.to_le_bytes();
743+
out[i] = bytes[0]; out[i + 1] = bytes[1];
744+
}
726745
Self(out)
727746
}
728747
#[inline(always)]
729-
pub fn simd_max(self, other: Self) -> Self {
748+
pub fn saturating_sub(self, other: Self) -> Self {
749+
let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].saturating_sub(other.0[i]); } Self(out)
750+
}
751+
#[inline(always)]
752+
pub fn unpack_lo_epi8(self, other: Self) -> Self {
753+
let mut out = [0u8; 64];
754+
for lane in 0..4 { let b = lane * 16; for i in 0..8 { out[b+i*2] = self.0[b+i]; out[b+i*2+1] = other.0[b+i]; } }
755+
Self(out)
756+
}
757+
#[inline(always)]
758+
pub fn unpack_hi_epi8(self, other: Self) -> Self {
730759
let mut out = [0u8; 64];
731-
for i in 0..64 { out[i] = self.0[i].max(other.0[i]); }
760+
for lane in 0..4 { let b = lane * 16; for i in 0..8 { out[b+i*2] = self.0[b+8+i]; out[b+i*2+1] = other.0[b+8+i]; } }
732761
Self(out)
733762
}
734763
}

src/simd_avx2.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,76 @@ macro_rules! avx2_int_type {
761761
}
762762

763763
avx2_int_type!(U8x64, u8, 64, 0u8);
764+
765+
// ── U8x64 byte-level operations (scalar fallback for AVX2 tier) ──────────
766+
// These match the AVX-512 U8x64 methods in simd_avx512.rs.
767+
impl U8x64 {
768+
/// Byte-wise equality mask: bit i set if self[i] == other[i].
769+
#[inline(always)]
770+
pub fn cmpeq_mask(self, other: Self) -> u64 {
771+
let mut mask = 0u64;
772+
for i in 0..64 { if self.0[i] == other.0[i] { mask |= 1u64 << i; } }
773+
mask
774+
}
775+
776+
/// Shift right each 16-bit lane by imm bits (operates on pairs of u8 as u16).
777+
#[inline(always)]
778+
pub fn shr_epi16(self, imm: u32) -> Self {
779+
let mut out = [0u8; 64];
780+
for i in (0..64).step_by(2) {
781+
let val = u16::from_le_bytes([self.0[i], self.0[i + 1]]);
782+
let shifted = val >> imm;
783+
let bytes = shifted.to_le_bytes();
784+
out[i] = bytes[0];
785+
out[i + 1] = bytes[1];
786+
}
787+
Self(out)
788+
}
789+
790+
/// Saturating unsigned subtraction: max(a - b, 0) per byte.
791+
#[inline(always)]
792+
pub fn saturating_sub(self, other: Self) -> Self {
793+
let mut out = [0u8; 64];
794+
for i in 0..64 { out[i] = self.0[i].saturating_sub(other.0[i]); }
795+
Self(out)
796+
}
797+
798+
/// Interleave low bytes within each 128-bit lane.
799+
#[inline(always)]
800+
pub fn unpack_lo_epi8(self, other: Self) -> Self {
801+
let mut out = [0u8; 64];
802+
// Operates per 16-byte lane (4 lanes in 512-bit)
803+
for lane in 0..4 {
804+
let base = lane * 16;
805+
for i in 0..8 {
806+
out[base + i * 2] = self.0[base + i];
807+
out[base + i * 2 + 1] = other.0[base + i];
808+
}
809+
}
810+
Self(out)
811+
}
812+
813+
/// Interleave high bytes within each 128-bit lane.
814+
#[inline(always)]
815+
pub fn unpack_hi_epi8(self, other: Self) -> Self {
816+
let mut out = [0u8; 64];
817+
for lane in 0..4 {
818+
let base = lane * 16;
819+
for i in 0..8 {
820+
out[base + i * 2] = self.0[base + 8 + i];
821+
out[base + i * 2 + 1] = other.0[base + 8 + i];
822+
}
823+
}
824+
Self(out)
825+
}
826+
827+
/// Reduce min/max (not in macro).
828+
#[inline(always)] pub fn reduce_min(self) -> u8 { *self.0.iter().min().unwrap() }
829+
#[inline(always)] pub fn reduce_max(self) -> u8 { *self.0.iter().max().unwrap() }
830+
#[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut o = [0u8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
831+
#[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut o = [0u8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
832+
}
833+
764834
avx2_int_type!(I32x16, i32, 16, 0i32);
765835
avx2_int_type!(I64x8, i64, 8, 0i64);
766836
avx2_int_type!(U32x16, u32, 16, 0u32);

src/simd_avx512.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,53 @@ impl U8x64 {
576576
pub fn simd_max(self, other: Self) -> Self {
577577
Self(unsafe { _mm512_max_epu8(self.0, other.0) })
578578
}
579+
580+
// ── Byte-level operations for palette codec, nibble, byte scan ──────
581+
// Reference: Pumpkin/Minecraft-derived modules (palette_codec.rs,
582+
// nibble.rs, byte_scan.rs) use these for 4-bit packing and scanning.
583+
584+
/// Byte-wise equality comparison. Returns 64-bit mask: bit i set if a[i] == b[i].
585+
#[inline(always)]
586+
pub fn cmpeq_mask(self, other: Self) -> u64 {
587+
unsafe { _mm512_cmpeq_epi8_mask(self.0, other.0) }
588+
}
589+
590+
/// Shift right each 16-bit lane by immediate bits (for nibble extraction).
591+
/// Note: operates on 16-bit lanes, not 8-bit — matches _mm512_srli_epi16.
592+
#[inline(always)]
593+
pub fn shr_epi16(self, imm: u32) -> Self {
594+
// _mm512_srli_epi16 shifts each 16-bit lane right
595+
// Use match for const immediate (intrinsic requires const)
596+
Self(unsafe { match imm {
597+
1 => _mm512_srli_epi16(self.0, 1),
598+
2 => _mm512_srli_epi16(self.0, 2),
599+
3 => _mm512_srli_epi16(self.0, 3),
600+
4 => _mm512_srli_epi16(self.0, 4),
601+
5 => _mm512_srli_epi16(self.0, 5),
602+
6 => _mm512_srli_epi16(self.0, 6),
603+
7 => _mm512_srli_epi16(self.0, 7),
604+
8 => _mm512_srli_epi16(self.0, 8),
605+
_ => _mm512_setzero_si512(),
606+
}})
607+
}
608+
609+
/// Saturating unsigned subtraction: max(a - b, 0) per byte.
610+
#[inline(always)]
611+
pub fn saturating_sub(self, other: Self) -> Self {
612+
Self(unsafe { _mm512_subs_epu8(self.0, other.0) })
613+
}
614+
615+
/// Interleave low bytes: [a0,b0,a1,b1,...] from lower halves.
616+
#[inline(always)]
617+
pub fn unpack_lo_epi8(self, other: Self) -> Self {
618+
Self(unsafe { _mm512_unpacklo_epi8(self.0, other.0) })
619+
}
620+
621+
/// Interleave high bytes: [a8,b8,a9,b9,...] from upper halves.
622+
#[inline(always)]
623+
pub fn unpack_hi_epi8(self, other: Self) -> Self {
624+
Self(unsafe { _mm512_unpackhi_epi8(self.0, other.0) })
625+
}
579626
}
580627

581628
// u8 add/sub use AVX-512BW instructions

0 commit comments

Comments
 (0)