Skip to content

Commit a5c8943

Browse files
AdaWorldAPI and claude authored
feat(simd): I8/I16 SIMD vectors + slice-level int ops (#124, sprint A3)
Adds the signed-byte / signed-half SIMD parity surface for the burn↔ndarray sprint:

Item 4 — types
• simd_avx512.rs: native I8x64 (__m512i) + I16x32 (__m512i) via AVX-512BW intrinsics (add/sub/min/max/cmp_gt/saturating/abs/neg). Plus AVX2-native I8x32 / I16x16 (__m256i) so the 256-bit signed types live in the same module as F32x8 / F64x4.
• simd_avx2.rs: scalar-array polyfills for I8x64 / I16x32 (the AVX2 tier doesn't have a 64-byte signed type) and re-exports of the AVX2-native I8x32 / I16x16 from simd_avx512.rs for unified imports.
• simd_neon.rs: NEON-native I8x16 (int8x16_t) + I16x8 (int16x8_t) via vaddq_s8 / vminq_s8 / vcgtq_s8 + paired/quadrupled scalar polyfills for I8x32 / I8x64 / I16x16 / I16x32.
• simd.rs: scalar fallbacks for non-x86_64/aarch64 targets and re-exports for every active tier so consumers write use ndarray::simd::{I8x32, I8x64, I16x16, I16x32};

Item 5 — slice ops (new file simd_int_ops.rs)
  add_i8 / add_i16 / sub_i8 / sub_i16 (mutate-in-place, wrapping)
  dot_i8 -> i32 (overflow-safe accumulator)
  dot_i16 -> i64 (overflow-safe accumulator)
  min_i8 / max_i8 / min_i16 / max_i16
Each chunks via the natural SIMD width of the active tier (64-byte AVX-512BW when available, 32-byte AVX2, 16-byte NEON) and finishes with a scalar tail.

Tests (+21 lib tests vs master baseline 1741 -> 1762):
• simd_avx512::int_simd_tests: 9 tests (gated on target_feature=avx512f) pair-sum 64, signed boundaries, cmp_gt mask, saturating arithmetic.
• simd_int_ops::tests: 11 tests misaligned tail lengths (63/65/127/129), 127i8 dot 127i8 x 64 overflow safety, signed boundary min/max, empty-slice identity.
• simd_avx2 polyfill build verified with RUSTFLAGS="-C target-feature=-avx512f".

Build host (this commit): AVX2 path (no avx512f at compile time -> uses the polyfill in simd_avx2.rs and simd.rs scalar mod for I8x64/I16x32).

Co-authored-by: Claude <noreply@anthropic.com>
1 parent f7d2406 commit a5c8943

6 files changed

Lines changed: 1399 additions & 3 deletions

File tree

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,11 @@ pub mod simd_neon;
252252
#[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)]
253253
pub mod simd_wasm;
254254

255+
/// Slice-level integer SIMD ops (i8/i16) — `add_i8`, `dot_i8`, `min_i8`, …
256+
#[cfg(feature = "std")]
257+
#[allow(missing_docs)]
258+
pub mod simd_int_ops;
259+
255260
/// Pluggable linear algebra backends (native SIMD, MKL, OpenBLAS).
256261
#[cfg(feature = "std")]
257262
pub mod backend;

src/simd.rs

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,14 @@ pub const PREFERRED_I16_LANES: usize = 16;
190190

191191
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
192192
pub use crate::simd_avx512::{
193-
// 256-bit (AVX2 baseline, __m256/__m256d)
194-
F32x8, F64x4, f32x8, f64x4,
193+
// 256-bit (AVX2 baseline, __m256/__m256d/__m256i)
194+
F32x8, F64x4, I8x32, I16x16, f32x8, f64x4, i8x32, i16x16,
195195
// 512-bit (native AVX-512, __m512/__m512d/__m512i)
196196
F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
197+
I8x64, I16x32,
197198
F32Mask16, F64Mask8,
198199
f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8,
200+
i8x64, i16x32,
199201
};
200202

201203
// BF16 types + batch conversion (always available — scalar fallback built in)
@@ -223,13 +225,15 @@ pub use crate::simd_avx512::{
223225
pub use crate::simd_avx512::{BF16x16, BF16x8};
224226

225227
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
226-
pub use crate::simd_avx512::{F32x8, F64x4, f32x8, f64x4};
228+
pub use crate::simd_avx512::{F32x8, F64x4, I8x32, I16x16, f32x8, f64x4, i8x32, i16x16};
227229

228230
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
229231
pub use crate::simd_avx2::{
230232
F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
233+
I8x64, I16x32,
231234
F32Mask16, F64Mask8,
232235
f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8,
236+
i8x64, i16x32,
233237
};
234238

235239
// ============================================================================
@@ -630,6 +634,62 @@ pub(crate) mod scalar {
630634
impl_int_type!(U32x16, u32, 16, 0u32);
631635
impl_int_type!(U64x8, u64, 8, 0u64);
632636

637+
// I8/I16 SIMD types (scalar fallback)
638+
impl_int_type!(I8x64, i8, 64, 0i8);
639+
impl_int_type!(I8x32, i8, 32, 0i8);
640+
impl_int_type!(I16x32, i16, 32, 0i16);
641+
impl_int_type!(I16x16, i16, 16, 0i16);
642+
643+
// I8x64 / I8x32 / I16x32 / I16x16 — AVX-512BW-style methods (scalar shape)
644+
impl I8x64 {
645+
#[inline(always)] pub fn zero() -> Self { Self([0i8; 64]) }
646+
#[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
647+
#[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
648+
#[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
649+
#[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
650+
#[inline(always)] pub fn cmp_gt(self, other: Self) -> u64 {
651+
let mut m: u64 = 0;
652+
for i in 0..64 { if self.0[i] > other.0[i] { m |= 1u64 << i; } }
653+
m
654+
}
655+
}
656+
impl I8x32 {
657+
#[inline(always)] pub fn zero() -> Self { Self([0i8; 32]) }
658+
#[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
659+
#[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
660+
#[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
661+
#[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
662+
#[inline(always)] pub fn cmp_gt(self, other: Self) -> u32 {
663+
let mut m: u32 = 0;
664+
for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } }
665+
m
666+
}
667+
}
668+
impl I16x32 {
669+
#[inline(always)] pub fn zero() -> Self { Self([0i16; 32]) }
670+
#[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
671+
#[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
672+
#[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
673+
#[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
674+
#[inline(always)] pub fn cmp_gt(self, other: Self) -> u32 {
675+
let mut m: u32 = 0;
676+
for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } }
677+
m
678+
}
679+
}
680+
impl I16x16 {
681+
#[inline(always)] pub fn zero() -> Self { Self([0i16; 16]) }
682+
#[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
683+
#[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
684+
#[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
685+
#[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
686+
#[inline(always)] pub fn cmp_gt(self, other: Self) -> u16 {
687+
let mut m: u16 = 0;
688+
for i in 0..16 { if self.0[i] > other.0[i] { m |= 1u16 << i; } }
689+
m
690+
}
691+
}
692+
633693
// Extra methods for U16x32 (widen/narrow, shift, multiply)
634694
impl U16x32 {
635695
#[inline(always)]
@@ -1012,6 +1072,10 @@ pub(crate) mod scalar {
10121072
#[allow(non_camel_case_types)] pub type u64x8 = U64x8;
10131073
#[allow(non_camel_case_types)] pub type f32x8 = F32x8;
10141074
#[allow(non_camel_case_types)] pub type f64x4 = F64x4;
1075+
#[allow(non_camel_case_types)] pub type i8x64 = I8x64;
1076+
#[allow(non_camel_case_types)] pub type i8x32 = I8x32;
1077+
#[allow(non_camel_case_types)] pub type i16x32 = I16x32;
1078+
#[allow(non_camel_case_types)] pub type i16x16 = I16x16;
10151079
}
10161080

10171081
// aarch64: F32x16/F64x8 come from the real NEON paired-load implementation
@@ -1036,9 +1100,11 @@ pub use scalar::{
10361100
pub use scalar::{
10371101
F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
10381102
F32x8, F64x4,
1103+
I8x64, I8x32, I16x32, I16x16,
10391104
F32Mask16, F64Mask8,
10401105
f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8,
10411106
f32x8, f64x4,
1107+
i8x64, i8x32, i16x32, i16x16,
10421108
};
10431109

10441110
// Scalar BF16 conversion — always available on all platforms

src/simd_avx2.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
99
use crate::simd_avx512::{f32x8, f64x4};
1010

11+
// AVX2-native I8x32 / I16x16 live in simd_avx512.rs (256-bit __m256i types).
12+
// Re-export so consumers see a unified `crate::simd_avx2::I8x32` symbol.
13+
pub use crate::simd_avx512::{I8x32, I16x16, i8x32, i16x16};
14+
1115
// ============================================================================
1216
// AVX2 lane counts (half of AVX-512)
1317
// ============================================================================
@@ -772,6 +776,47 @@ macro_rules! avx2_int_type {
772776
}
773777

774778
avx2_int_type!(U8x64, u8, 64, 0u8);
779+
avx2_int_type!(I8x64, i8, 64, 0i8);
780+
avx2_int_type!(I16x32, i16, 32, 0i16);
781+
782+
// I8x64 / I16x32: AVX2 scalar polyfill — methods matching the AVX-512BW API
783+
impl I8x64 {
784+
#[inline(always)]
785+
pub fn zero() -> Self { Self([0i8; 64]) }
786+
#[inline(always)]
787+
pub fn add(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
788+
#[inline(always)]
789+
pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
790+
#[inline(always)]
791+
pub fn min(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
792+
#[inline(always)]
793+
pub fn max(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
794+
#[inline(always)]
795+
pub fn cmp_gt(self, other: Self) -> u64 {
796+
let mut m: u64 = 0;
797+
for i in 0..64 { if self.0[i] > other.0[i] { m |= 1u64 << i; } }
798+
m
799+
}
800+
}
801+
802+
impl I16x32 {
803+
#[inline(always)]
804+
pub fn zero() -> Self { Self([0i16; 32]) }
805+
#[inline(always)]
806+
pub fn add(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
807+
#[inline(always)]
808+
pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
809+
#[inline(always)]
810+
pub fn min(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
811+
#[inline(always)]
812+
pub fn max(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
813+
#[inline(always)]
814+
pub fn cmp_gt(self, other: Self) -> u32 {
815+
let mut m: u32 = 0;
816+
for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } }
817+
m
818+
}
819+
}
775820

776821
// ── U8x64 byte-level operations (scalar fallback for AVX2 tier) ──────────
777822
// These match the AVX-512 U8x64 methods in simd_avx512.rs.
@@ -1007,6 +1052,10 @@ pub type i64x8 = I64x8;
10071052
pub type u32x16 = U32x16;
10081053
#[allow(non_camel_case_types)]
10091054
pub type u64x8 = U64x8;
1055+
#[allow(non_camel_case_types)]
1056+
pub type i8x64 = I8x64;
1057+
#[allow(non_camel_case_types)]
1058+
pub type i16x32 = I16x32;
10101059

10111060
#[cfg(test)]
10121061
mod tests {

0 commit comments

Comments
 (0)