|
| 1 | +//! Slice-level elementwise ops built on the polyfill SIMD types. |
| 2 | +//! |
| 3 | +//! Every function uses `crate::simd::F32x16` (or the appropriate type), |
| 4 | +//! which is already dispatched: AVX-512 → AVX2 → NEON → scalar. |
| 5 | +//! These ops inherit that dispatch — no platform-specific code here. |
| 6 | +//! |
| 7 | +//! Re-exported flat through `ndarray::simd::add_f32`, etc. |
| 8 | +
|
| 9 | +use crate::simd::{F32x16, F64x8}; |
| 10 | + |
| 11 | +// ═══════════════════════════════════════════════════════════════════ |
| 12 | +// f32 binary ops (out-of-place) |
| 13 | +// ═══════════════════════════════════════════════════════════════════ |
| 14 | + |
| 15 | +/// Elementwise add: `out[i] = a[i] + b[i]`. |
| 16 | +pub fn add_f32(a: &[f32], b: &[f32]) -> Vec<f32> { |
| 17 | + binary_f32(a, b, |x, y| x + y, |x, y| x + y) |
| 18 | +} |
| 19 | + |
| 20 | +/// Elementwise subtract: `out[i] = a[i] - b[i]`. |
| 21 | +pub fn sub_f32(a: &[f32], b: &[f32]) -> Vec<f32> { |
| 22 | + binary_f32(a, b, |x, y| x - y, |x, y| x - y) |
| 23 | +} |
| 24 | + |
| 25 | +/// Elementwise multiply: `out[i] = a[i] * b[i]`. |
| 26 | +pub fn mul_f32(a: &[f32], b: &[f32]) -> Vec<f32> { |
| 27 | + binary_f32(a, b, |x, y| x * y, |x, y| x * y) |
| 28 | +} |
| 29 | + |
| 30 | +/// Elementwise divide: `out[i] = a[i] / b[i]`. |
| 31 | +pub fn div_f32(a: &[f32], b: &[f32]) -> Vec<f32> { |
| 32 | + binary_f32(a, b, |x, y| x / y, |x, y| x / y) |
| 33 | +} |
| 34 | + |
| 35 | +// ═══════════════════════════════════════════════════════════════════ |
| 36 | +// f32 inplace ops |
| 37 | +// ═══════════════════════════════════════════════════════════════════ |
| 38 | + |
| 39 | +/// Inplace add: `dst[i] += src[i]`. |
| 40 | +pub fn add_f32_inplace(dst: &mut [f32], src: &[f32]) { |
| 41 | + inplace_f32(dst, src, |d, s| d + s, |d, s| *d += s) |
| 42 | +} |
| 43 | + |
| 44 | +/// Inplace subtract: `dst[i] -= src[i]`. |
| 45 | +pub fn sub_f32_inplace(dst: &mut [f32], src: &[f32]) { |
| 46 | + inplace_f32(dst, src, |d, s| d - s, |d, s| *d -= s) |
| 47 | +} |
| 48 | + |
| 49 | +/// Inplace multiply: `dst[i] *= src[i]`. |
| 50 | +pub fn mul_f32_inplace(dst: &mut [f32], src: &[f32]) { |
| 51 | + inplace_f32(dst, src, |d, s| d * s, |d, s| *d *= s) |
| 52 | +} |
| 53 | + |
| 54 | +/// Inplace divide: `dst[i] /= src[i]`. |
| 55 | +pub fn div_f32_inplace(dst: &mut [f32], src: &[f32]) { |
| 56 | + inplace_f32(dst, src, |d, s| d / s, |d, s| *d /= s) |
| 57 | +} |
| 58 | + |
| 59 | +// ═══════════════════════════════════════════════════════════════════ |
| 60 | +// f32 scalar ops |
| 61 | +// ═══════════════════════════════════════════════════════════════════ |
| 62 | + |
| 63 | +/// Scalar multiply: `out[i] = a[i] * scalar`. |
| 64 | +pub fn scale_f32(a: &[f32], scalar: f32) -> Vec<f32> { |
| 65 | + let s = F32x16::splat(scalar); |
| 66 | + let n = a.len(); |
| 67 | + let mut out = vec![0.0f32; n]; |
| 68 | + let mut i = 0; |
| 69 | + while i + 16 <= n { |
| 70 | + (F32x16::from_slice(&a[i..]) * s).copy_to_slice(&mut out[i..]); |
| 71 | + i += 16; |
| 72 | + } |
| 73 | + while i < n { out[i] = a[i] * scalar; i += 1; } |
| 74 | + out |
| 75 | +} |
| 76 | + |
| 77 | +/// Scalar add: `out[i] = a[i] + scalar`. |
| 78 | +pub fn add_scalar_f32(a: &[f32], scalar: f32) -> Vec<f32> { |
| 79 | + let s = F32x16::splat(scalar); |
| 80 | + let n = a.len(); |
| 81 | + let mut out = vec![0.0f32; n]; |
| 82 | + let mut i = 0; |
| 83 | + while i + 16 <= n { |
| 84 | + (F32x16::from_slice(&a[i..]) + s).copy_to_slice(&mut out[i..]); |
| 85 | + i += 16; |
| 86 | + } |
| 87 | + while i < n { out[i] = a[i] + scalar; i += 1; } |
| 88 | + out |
| 89 | +} |
| 90 | + |
| 91 | +/// Inplace scalar multiply: `a[i] *= scalar`. |
| 92 | +pub fn scale_f32_inplace(a: &mut [f32], scalar: f32) { |
| 93 | + let s = F32x16::splat(scalar); |
| 94 | + let n = a.len(); |
| 95 | + let mut i = 0; |
| 96 | + while i + 16 <= n { |
| 97 | + (F32x16::from_slice(&a[i..]) * s).copy_to_slice(&mut a[i..]); |
| 98 | + i += 16; |
| 99 | + } |
| 100 | + while i < n { a[i] *= scalar; i += 1; } |
| 101 | +} |
| 102 | + |
| 103 | +// ═══════════════════════════════════════════════════════════════════ |
| 104 | +// f64 binary ops |
| 105 | +// ═══════════════════════════════════════════════════════════════════ |
| 106 | + |
| 107 | +/// Elementwise add f64: `out[i] = a[i] + b[i]`. |
| 108 | +pub fn add_f64(a: &[f64], b: &[f64]) -> Vec<f64> { |
| 109 | + binary_f64(a, b, |x, y| x + y, |x, y| x + y) |
| 110 | +} |
| 111 | + |
| 112 | +/// Elementwise multiply f64: `out[i] = a[i] * b[i]`. |
| 113 | +pub fn mul_f64(a: &[f64], b: &[f64]) -> Vec<f64> { |
| 114 | + binary_f64(a, b, |x, y| x * y, |x, y| x * y) |
| 115 | +} |
| 116 | + |
| 117 | +/// Inplace add f64: `dst[i] += src[i]`. |
| 118 | +pub fn add_f64_inplace(dst: &mut [f64], src: &[f64]) { |
| 119 | + inplace_f64(dst, src, |d, s| d + s, |d, s| *d += s) |
| 120 | +} |
| 121 | + |
| 122 | +// ═══════════════════════════════════════════════════════════════════ |
| 123 | +// Internal dispatch helpers |
| 124 | +// ═══════════════════════════════════════════════════════════════════ |
| 125 | + |
| 126 | +#[inline] |
| 127 | +fn binary_f32( |
| 128 | + a: &[f32], b: &[f32], |
| 129 | + simd_op: impl Fn(F32x16, F32x16) -> F32x16, |
| 130 | + scalar_op: impl Fn(f32, f32) -> f32, |
| 131 | +) -> Vec<f32> { |
| 132 | + let n = a.len().min(b.len()); |
| 133 | + let mut out = vec![0.0f32; n]; |
| 134 | + let mut i = 0; |
| 135 | + while i + 16 <= n { |
| 136 | + simd_op(F32x16::from_slice(&a[i..]), F32x16::from_slice(&b[i..])) |
| 137 | + .copy_to_slice(&mut out[i..]); |
| 138 | + i += 16; |
| 139 | + } |
| 140 | + while i < n { out[i] = scalar_op(a[i], b[i]); i += 1; } |
| 141 | + out |
| 142 | +} |
| 143 | + |
| 144 | +#[inline] |
| 145 | +fn inplace_f32( |
| 146 | + dst: &mut [f32], src: &[f32], |
| 147 | + simd_op: impl Fn(F32x16, F32x16) -> F32x16, |
| 148 | + scalar_op: impl Fn(&mut f32, f32), |
| 149 | +) { |
| 150 | + let n = dst.len().min(src.len()); |
| 151 | + let mut i = 0; |
| 152 | + while i + 16 <= n { |
| 153 | + simd_op(F32x16::from_slice(&dst[i..]), F32x16::from_slice(&src[i..])) |
| 154 | + .copy_to_slice(&mut dst[i..]); |
| 155 | + i += 16; |
| 156 | + } |
| 157 | + while i < n { scalar_op(&mut dst[i], src[i]); i += 1; } |
| 158 | +} |
| 159 | + |
| 160 | +#[inline] |
| 161 | +fn binary_f64( |
| 162 | + a: &[f64], b: &[f64], |
| 163 | + simd_op: impl Fn(F64x8, F64x8) -> F64x8, |
| 164 | + scalar_op: impl Fn(f64, f64) -> f64, |
| 165 | +) -> Vec<f64> { |
| 166 | + let n = a.len().min(b.len()); |
| 167 | + let mut out = vec![0.0f64; n]; |
| 168 | + let mut i = 0; |
| 169 | + while i + 8 <= n { |
| 170 | + simd_op(F64x8::from_slice(&a[i..]), F64x8::from_slice(&b[i..])) |
| 171 | + .copy_to_slice(&mut out[i..]); |
| 172 | + i += 8; |
| 173 | + } |
| 174 | + while i < n { out[i] = scalar_op(a[i], b[i]); i += 1; } |
| 175 | + out |
| 176 | +} |
| 177 | + |
| 178 | +#[inline] |
| 179 | +fn inplace_f64( |
| 180 | + dst: &mut [f64], src: &[f64], |
| 181 | + simd_op: impl Fn(F64x8, F64x8) -> F64x8, |
| 182 | + scalar_op: impl Fn(&mut f64, f64), |
| 183 | +) { |
| 184 | + let n = dst.len().min(src.len()); |
| 185 | + let mut i = 0; |
| 186 | + while i + 8 <= n { |
| 187 | + simd_op(F64x8::from_slice(&dst[i..]), F64x8::from_slice(&src[i..])) |
| 188 | + .copy_to_slice(&mut dst[i..]); |
| 189 | + i += 8; |
| 190 | + } |
| 191 | + while i < n { scalar_op(&mut dst[i], src[i]); i += 1; } |
| 192 | +} |
| 193 | + |
| 194 | +// ═══════════════════════════════════════════════════════════════════ |
| 195 | +// Tests |
| 196 | +// ═══════════════════════════════════════════════════════════════════ |
| 197 | + |
| 198 | +#[cfg(test)] |
| 199 | +mod tests { |
| 200 | + use super::*; |
| 201 | + |
| 202 | + #[test] |
| 203 | + fn add_f32_aligned() { |
| 204 | + let a = vec![1.0f32; 32]; |
| 205 | + let b = vec![2.0f32; 32]; |
| 206 | + let c = add_f32(&a, &b); |
| 207 | + assert!(c.iter().all(|&v| (v - 3.0).abs() < 1e-6)); |
| 208 | + } |
| 209 | + |
| 210 | + #[test] |
| 211 | + fn add_f32_misaligned_tail() { |
| 212 | + let a = vec![1.0f32; 33]; |
| 213 | + let b = vec![2.0f32; 33]; |
| 214 | + let c = add_f32(&a, &b); |
| 215 | + assert_eq!(c.len(), 33); |
| 216 | + assert!(c.iter().all(|&v| (v - 3.0).abs() < 1e-6)); |
| 217 | + } |
| 218 | + |
| 219 | + #[test] |
| 220 | + fn mul_f32_inplace_works() { |
| 221 | + let mut dst = vec![2.0f32; 17]; |
| 222 | + let src = vec![3.0f32; 17]; |
| 223 | + mul_f32_inplace(&mut dst, &src); |
| 224 | + assert!(dst.iter().all(|&v| (v - 6.0).abs() < 1e-6)); |
| 225 | + } |
| 226 | + |
| 227 | + #[test] |
| 228 | + fn scale_f32_works() { |
| 229 | + let a = vec![4.0f32; 35]; |
| 230 | + let b = scale_f32(&a, 0.5); |
| 231 | + assert!(b.iter().all(|&v| (v - 2.0).abs() < 1e-6)); |
| 232 | + } |
| 233 | + |
| 234 | + #[test] |
| 235 | + fn scale_f32_inplace_works() { |
| 236 | + let mut a = vec![10.0f32; 19]; |
| 237 | + scale_f32_inplace(&mut a, 0.1); |
| 238 | + assert!(a.iter().all(|&v| (v - 1.0).abs() < 1e-5)); |
| 239 | + } |
| 240 | + |
| 241 | + #[test] |
| 242 | + fn add_scalar_f32_works() { |
| 243 | + let a = vec![1.0f32; 20]; |
| 244 | + let b = add_scalar_f32(&a, 99.0); |
| 245 | + assert!(b.iter().all(|&v| (v - 100.0).abs() < 1e-6)); |
| 246 | + } |
| 247 | + |
| 248 | + #[test] |
| 249 | + fn sub_f32_works() { |
| 250 | + let c = sub_f32(&[5.0; 3], &[2.0; 3]); |
| 251 | + assert!(c.iter().all(|&v| (v - 3.0).abs() < 1e-6)); |
| 252 | + } |
| 253 | + |
| 254 | + #[test] |
| 255 | + fn div_f32_works() { |
| 256 | + let c = div_f32(&[6.0; 4], &[3.0; 4]); |
| 257 | + assert!(c.iter().all(|&v| (v - 2.0).abs() < 1e-6)); |
| 258 | + } |
| 259 | + |
| 260 | + #[test] |
| 261 | + fn add_f64_works() { |
| 262 | + let c = add_f64(&[1.0f64; 17], &[2.0f64; 17]); |
| 263 | + assert_eq!(c.len(), 17); |
| 264 | + assert!(c.iter().all(|&v| (v - 3.0).abs() < 1e-12)); |
| 265 | + } |
| 266 | + |
| 267 | + #[test] |
| 268 | + fn empty_slices() { |
| 269 | + assert!(add_f32(&[], &[]).is_empty()); |
| 270 | + assert!(mul_f32(&[], &[]).is_empty()); |
| 271 | + assert!(scale_f32(&[], 2.0).is_empty()); |
| 272 | + } |
| 273 | + |
| 274 | + #[test] |
| 275 | + fn mismatched_lengths_takes_min() { |
| 276 | + let c = add_f32(&[1.0; 10], &[2.0; 5]); |
| 277 | + assert_eq!(c.len(), 5); |
| 278 | + } |
| 279 | +} |
0 commit comments