Skip to content

Commit c1c7ae4

Browse files
committed
feat(simd): elementwise slice ops via polyfill dispatch (simd_ops.rs)
New src/simd_ops.rs — slice-level elementwise operations built on the polyfill SIMD types (F32x16/F64x8). No platform-specific code in this file; it uses operator traits (+, -, *, /) on the already-dispatched types so it works on AVX-512, AVX2, NEON, and scalar identically. Consumer surface: use ndarray::simd::{ add_f32, sub_f32, mul_f32, div_f32, add_f32_inplace, sub_f32_inplace, mul_f32_inplace, div_f32_inplace, scale_f32, add_scalar_f32, scale_f32_inplace, add_f64, mul_f64, add_f64_inplace, }; Each f32 function: F32x16 chunks (16 elements/iteration) + scalar tail; each f64 function: F64x8 chunks (8 elements/iteration) + scalar tail. Inplace variants modify dst in-place. Scalar variants (scale_f32, scale_f32_inplace, add_scalar_f32) broadcast a scalar. 11 tests covering: aligned, misaligned tail, empty, mismatched lengths. https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
1 parent 00b6ee5 commit c1c7ae4

3 files changed

Lines changed: 294 additions & 0 deletions

File tree

src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,12 @@ pub mod simd_wasm;
257257
#[allow(missing_docs)]
258258
pub mod simd_int_ops;
259259

260+
/// Slice-level elementwise ops (f32/f64) built on the polyfill SIMD types.
261+
/// `add_f32`, `mul_f32`, `add_f32_inplace`, `scale_f32`, etc.
262+
/// Re-exported flat through `ndarray::simd::add_f32`.
263+
#[cfg(feature = "std")]
264+
pub mod simd_ops;
265+
260266
/// Half-precision SIMD vectors (`BF16x16`, `F16x16`) + slice-level ops.
261267
/// Depends on `hpc::quantized::{BF16, F16}` — needs `std` (hpc core).
262268
#[cfg(feature = "std")]

src/simd.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,15 @@ pub use crate::hpc::cam_pq::{kmeans, squared_l2};
12361236

12371237
pub use crate::hpc::heel_f64x8::cosine_f32_to_f64_simd;
12381238

1239+
// Elementwise slice ops — polyfill-dispatched (F32x16/F64x8 chunks + scalar tail).
1240+
#[cfg(feature = "std")]
1241+
pub use crate::simd_ops::{
1242+
add_f32, sub_f32, mul_f32, div_f32,
1243+
add_f32_inplace, sub_f32_inplace, mul_f32_inplace, div_f32_inplace,
1244+
scale_f32, add_scalar_f32, scale_f32_inplace,
1245+
add_f64, mul_f64, add_f64_inplace,
1246+
};
1247+
12391248
// ============================================================================
12401249
// Tests
12411250
// ============================================================================

src/simd_ops.rs

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
//! Slice-level elementwise ops built on the polyfill SIMD types.
2+
//!
3+
//! Every function uses `crate::simd::F32x16` (or the appropriate type),
4+
//! which is already dispatched: AVX-512 → AVX2 → NEON → scalar.
5+
//! These ops inherit that dispatch — no platform-specific code here.
6+
//!
7+
//! Re-exported flat through `ndarray::simd::add_f32`, etc.
8+
9+
use crate::simd::{F32x16, F64x8};
10+
11+
// ═══════════════════════════════════════════════════════════════════
12+
// f32 binary ops (out-of-place)
13+
// ═══════════════════════════════════════════════════════════════════
14+
15+
/// Elementwise add: `out[i] = a[i] + b[i]`.
16+
pub fn add_f32(a: &[f32], b: &[f32]) -> Vec<f32> {
17+
binary_f32(a, b, |x, y| x + y, |x, y| x + y)
18+
}
19+
20+
/// Elementwise subtract: `out[i] = a[i] - b[i]`.
21+
pub fn sub_f32(a: &[f32], b: &[f32]) -> Vec<f32> {
22+
binary_f32(a, b, |x, y| x - y, |x, y| x - y)
23+
}
24+
25+
/// Elementwise multiply: `out[i] = a[i] * b[i]`.
26+
pub fn mul_f32(a: &[f32], b: &[f32]) -> Vec<f32> {
27+
binary_f32(a, b, |x, y| x * y, |x, y| x * y)
28+
}
29+
30+
/// Elementwise divide: `out[i] = a[i] / b[i]`.
31+
pub fn div_f32(a: &[f32], b: &[f32]) -> Vec<f32> {
32+
binary_f32(a, b, |x, y| x / y, |x, y| x / y)
33+
}
34+
35+
// ═══════════════════════════════════════════════════════════════════
36+
// f32 inplace ops
37+
// ═══════════════════════════════════════════════════════════════════
38+
39+
/// Inplace add: `dst[i] += src[i]`.
40+
pub fn add_f32_inplace(dst: &mut [f32], src: &[f32]) {
41+
inplace_f32(dst, src, |d, s| d + s, |d, s| *d += s)
42+
}
43+
44+
/// Inplace subtract: `dst[i] -= src[i]`.
45+
pub fn sub_f32_inplace(dst: &mut [f32], src: &[f32]) {
46+
inplace_f32(dst, src, |d, s| d - s, |d, s| *d -= s)
47+
}
48+
49+
/// Inplace multiply: `dst[i] *= src[i]`.
50+
pub fn mul_f32_inplace(dst: &mut [f32], src: &[f32]) {
51+
inplace_f32(dst, src, |d, s| d * s, |d, s| *d *= s)
52+
}
53+
54+
/// Inplace divide: `dst[i] /= src[i]`.
55+
pub fn div_f32_inplace(dst: &mut [f32], src: &[f32]) {
56+
inplace_f32(dst, src, |d, s| d / s, |d, s| *d /= s)
57+
}
58+
59+
// ═══════════════════════════════════════════════════════════════════
60+
// f32 scalar ops
61+
// ═══════════════════════════════════════════════════════════════════
62+
63+
/// Scalar multiply: `out[i] = a[i] * scalar`.
64+
pub fn scale_f32(a: &[f32], scalar: f32) -> Vec<f32> {
65+
let s = F32x16::splat(scalar);
66+
let n = a.len();
67+
let mut out = vec![0.0f32; n];
68+
let mut i = 0;
69+
while i + 16 <= n {
70+
(F32x16::from_slice(&a[i..]) * s).copy_to_slice(&mut out[i..]);
71+
i += 16;
72+
}
73+
while i < n { out[i] = a[i] * scalar; i += 1; }
74+
out
75+
}
76+
77+
/// Scalar add: `out[i] = a[i] + scalar`.
78+
pub fn add_scalar_f32(a: &[f32], scalar: f32) -> Vec<f32> {
79+
let s = F32x16::splat(scalar);
80+
let n = a.len();
81+
let mut out = vec![0.0f32; n];
82+
let mut i = 0;
83+
while i + 16 <= n {
84+
(F32x16::from_slice(&a[i..]) + s).copy_to_slice(&mut out[i..]);
85+
i += 16;
86+
}
87+
while i < n { out[i] = a[i] + scalar; i += 1; }
88+
out
89+
}
90+
91+
/// Inplace scalar multiply: `a[i] *= scalar`.
92+
pub fn scale_f32_inplace(a: &mut [f32], scalar: f32) {
93+
let s = F32x16::splat(scalar);
94+
let n = a.len();
95+
let mut i = 0;
96+
while i + 16 <= n {
97+
(F32x16::from_slice(&a[i..]) * s).copy_to_slice(&mut a[i..]);
98+
i += 16;
99+
}
100+
while i < n { a[i] *= scalar; i += 1; }
101+
}
102+
103+
// ═══════════════════════════════════════════════════════════════════
104+
// f64 binary ops
105+
// ═══════════════════════════════════════════════════════════════════
106+
107+
/// Elementwise add f64: `out[i] = a[i] + b[i]`.
108+
pub fn add_f64(a: &[f64], b: &[f64]) -> Vec<f64> {
109+
binary_f64(a, b, |x, y| x + y, |x, y| x + y)
110+
}
111+
112+
/// Elementwise multiply f64: `out[i] = a[i] * b[i]`.
113+
pub fn mul_f64(a: &[f64], b: &[f64]) -> Vec<f64> {
114+
binary_f64(a, b, |x, y| x * y, |x, y| x * y)
115+
}
116+
117+
/// Inplace add f64: `dst[i] += src[i]`.
118+
pub fn add_f64_inplace(dst: &mut [f64], src: &[f64]) {
119+
inplace_f64(dst, src, |d, s| d + s, |d, s| *d += s)
120+
}
121+
122+
// ═══════════════════════════════════════════════════════════════════
123+
// Internal dispatch helpers
124+
// ═══════════════════════════════════════════════════════════════════
125+
126+
#[inline]
127+
fn binary_f32(
128+
a: &[f32], b: &[f32],
129+
simd_op: impl Fn(F32x16, F32x16) -> F32x16,
130+
scalar_op: impl Fn(f32, f32) -> f32,
131+
) -> Vec<f32> {
132+
let n = a.len().min(b.len());
133+
let mut out = vec![0.0f32; n];
134+
let mut i = 0;
135+
while i + 16 <= n {
136+
simd_op(F32x16::from_slice(&a[i..]), F32x16::from_slice(&b[i..]))
137+
.copy_to_slice(&mut out[i..]);
138+
i += 16;
139+
}
140+
while i < n { out[i] = scalar_op(a[i], b[i]); i += 1; }
141+
out
142+
}
143+
144+
#[inline]
145+
fn inplace_f32(
146+
dst: &mut [f32], src: &[f32],
147+
simd_op: impl Fn(F32x16, F32x16) -> F32x16,
148+
scalar_op: impl Fn(&mut f32, f32),
149+
) {
150+
let n = dst.len().min(src.len());
151+
let mut i = 0;
152+
while i + 16 <= n {
153+
simd_op(F32x16::from_slice(&dst[i..]), F32x16::from_slice(&src[i..]))
154+
.copy_to_slice(&mut dst[i..]);
155+
i += 16;
156+
}
157+
while i < n { scalar_op(&mut dst[i], src[i]); i += 1; }
158+
}
159+
160+
#[inline]
161+
fn binary_f64(
162+
a: &[f64], b: &[f64],
163+
simd_op: impl Fn(F64x8, F64x8) -> F64x8,
164+
scalar_op: impl Fn(f64, f64) -> f64,
165+
) -> Vec<f64> {
166+
let n = a.len().min(b.len());
167+
let mut out = vec![0.0f64; n];
168+
let mut i = 0;
169+
while i + 8 <= n {
170+
simd_op(F64x8::from_slice(&a[i..]), F64x8::from_slice(&b[i..]))
171+
.copy_to_slice(&mut out[i..]);
172+
i += 8;
173+
}
174+
while i < n { out[i] = scalar_op(a[i], b[i]); i += 1; }
175+
out
176+
}
177+
178+
#[inline]
179+
fn inplace_f64(
180+
dst: &mut [f64], src: &[f64],
181+
simd_op: impl Fn(F64x8, F64x8) -> F64x8,
182+
scalar_op: impl Fn(&mut f64, f64),
183+
) {
184+
let n = dst.len().min(src.len());
185+
let mut i = 0;
186+
while i + 8 <= n {
187+
simd_op(F64x8::from_slice(&dst[i..]), F64x8::from_slice(&src[i..]))
188+
.copy_to_slice(&mut dst[i..]);
189+
i += 8;
190+
}
191+
while i < n { scalar_op(&mut dst[i], src[i]); i += 1; }
192+
}
193+
194+
// ═══════════════════════════════════════════════════════════════════
195+
// Tests
196+
// ═══════════════════════════════════════════════════════════════════
197+
198+
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `0.0, 1.0, 2.0, …` so that every element is distinct. With
    /// uniform inputs a result written to the wrong lane or offset would
    /// still look correct; with a ramp it does not.
    fn ramp_f32(n: usize) -> Vec<f32> {
        (0..n).map(|i| i as f32).collect()
    }

    /// `f64` counterpart of [`ramp_f32`].
    fn ramp_f64(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64).collect()
    }

    #[test]
    fn add_f32_aligned() {
        let a = vec![1.0f32; 32];
        let b = vec![2.0f32; 32];
        let c = add_f32(&a, &b);
        assert!(c.iter().all(|&v| (v - 3.0).abs() < 1e-6));
    }

    #[test]
    fn add_f32_misaligned_tail() {
        let a = vec![1.0f32; 33];
        let b = vec![2.0f32; 33];
        let c = add_f32(&a, &b);
        assert_eq!(c.len(), 33);
        assert!(c.iter().all(|&v| (v - 3.0).abs() < 1e-6));
    }

    #[test]
    fn add_f32_distinct_lanes() {
        // 37 = two full vectors + a 5-element tail.
        let a = ramp_f32(37);
        let b: Vec<f32> = a.iter().map(|&v| 2.0 * v + 1.0).collect();
        let c = add_f32(&a, &b);
        assert_eq!(c.len(), 37);
        for (i, &got) in c.iter().enumerate() {
            let want = 3.0 * (i as f32) + 1.0;
            assert!((got - want).abs() < 1e-4, "index {}: got {}, want {}", i, got, want);
        }
    }

    #[test]
    fn mul_f32_inplace_works() {
        let mut dst = vec![2.0f32; 17];
        let src = vec![3.0f32; 17];
        mul_f32_inplace(&mut dst, &src);
        assert!(dst.iter().all(|&v| (v - 6.0).abs() < 1e-6));
    }

    #[test]
    fn add_then_sub_f32_inplace_round_trips() {
        let original = ramp_f32(35);
        let delta = vec![4.0f32; 35];
        let mut dst = original.clone();

        add_f32_inplace(&mut dst, &delta);
        assert!(dst.iter().zip(&original).all(|(&d, &o)| (d - (o + 4.0)).abs() < 1e-6));

        sub_f32_inplace(&mut dst, &delta);
        assert!(dst.iter().zip(&original).all(|(&d, &o)| (d - o).abs() < 1e-6));
    }

    #[test]
    fn div_f32_inplace_works() {
        let mut dst = ramp_f32(18);
        let src = vec![4.0f32; 18];
        div_f32_inplace(&mut dst, &src);
        for (i, &got) in dst.iter().enumerate() {
            let want = (i as f32) / 4.0;
            assert!((got - want).abs() < 1e-6, "index {}: got {}, want {}", i, got, want);
        }
    }

    #[test]
    fn inplace_mismatched_lengths_leave_tail_untouched() {
        let mut dst = vec![1.0f32; 20];
        add_f32_inplace(&mut dst, &[2.0; 17]);
        assert!(dst[..17].iter().all(|&v| (v - 3.0).abs() < 1e-6));
        // Elements past the shorter length must be bit-for-bit unchanged.
        assert!(dst[17..].iter().all(|&v| v.to_bits() == 1.0f32.to_bits()));
    }

    #[test]
    fn scale_f32_works() {
        let a = vec![4.0f32; 35];
        let b = scale_f32(&a, 0.5);
        assert!(b.iter().all(|&v| (v - 2.0).abs() < 1e-6));
    }

    #[test]
    fn scale_f32_inplace_works() {
        let mut a = vec![10.0f32; 19];
        scale_f32_inplace(&mut a, 0.1);
        assert!(a.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn add_scalar_f32_works() {
        let a = vec![1.0f32; 20];
        let b = add_scalar_f32(&a, 99.0);
        assert!(b.iter().all(|&v| (v - 100.0).abs() < 1e-6));
    }

    #[test]
    fn sub_f32_works() {
        let c = sub_f32(&[5.0; 3], &[2.0; 3]);
        assert!(c.iter().all(|&v| (v - 3.0).abs() < 1e-6));
    }

    #[test]
    fn div_f32_works() {
        let c = div_f32(&[6.0; 4], &[3.0; 4]);
        assert!(c.iter().all(|&v| (v - 2.0).abs() < 1e-6));
    }

    #[test]
    fn add_f64_works() {
        let c = add_f64(&[1.0f64; 17], &[2.0f64; 17]);
        assert_eq!(c.len(), 17);
        assert!(c.iter().all(|&v| (v - 3.0).abs() < 1e-12));
    }

    #[test]
    fn mul_f64_and_add_f64_inplace_work() {
        // 19 = two full F64x8 vectors + a 3-element tail.
        let a = ramp_f64(19);
        let doubled = mul_f64(&a, &[2.0f64; 19]);
        assert_eq!(doubled.len(), 19);
        for (i, &got) in doubled.iter().enumerate() {
            let want = 2.0 * (i as f64);
            assert!((got - want).abs() < 1e-12, "index {}: got {}, want {}", i, got, want);
        }

        let mut acc = a.clone();
        add_f64_inplace(&mut acc, &doubled);
        for (i, &got) in acc.iter().enumerate() {
            let want = 3.0 * (i as f64);
            assert!((got - want).abs() < 1e-12, "index {}: got {}, want {}", i, got, want);
        }
    }

    #[test]
    fn empty_slices() {
        assert!(add_f32(&[], &[]).is_empty());
        assert!(mul_f32(&[], &[]).is_empty());
        assert!(scale_f32(&[], 2.0).is_empty());
    }

    #[test]
    fn mismatched_lengths_takes_min() {
        let c = add_f32(&[1.0; 10], &[2.0; 5]);
        assert_eq!(c.len(), 5);
    }
}

0 commit comments

Comments
 (0)