|
| 1 | + |
| 2 | +#include "float4_dispatch.h" |
| 3 | +#include <smmintrin.h> // SSE41 intrinsics |
| 4 | + |
| 5 | +namespace |
| 6 | +{ |
| 7 | + typedef __m128 f32x4; |
| 8 | + |
| 9 | + // Load 4 floats from memory into a SIMD register |
| 10 | + inline f32x4 v_load(const float* p) { return _mm_loadu_ps(p); } |
| 11 | + |
| 12 | + // Store 4 floats from SIMD register back to memory |
| 13 | + inline void v_store(float* dst, f32x4 v) { _mm_storeu_ps(dst, v); } |
| 14 | + |
| 15 | + // Broadcast a single float across all 4 lanes |
| 16 | + inline f32x4 v_set1(float s) { return _mm_set1_ps(s); } |
| 17 | + |
| 18 | + // Element-wise multiply |
| 19 | + inline f32x4 v_mul(f32x4 a, f32x4 b) { return _mm_mul_ps(a, b); } |
| 20 | + |
| 21 | + // Element-wise divide |
| 22 | + inline f32x4 v_div(f32x4 a, f32x4 b) { return _mm_div_ps(a, b); } |
| 23 | + |
| 24 | + // Element-wise add |
| 25 | + inline f32x4 v_add(f32x4 a, f32x4 b) { return _mm_add_ps(a, b); } |
| 26 | + |
| 27 | + // Element-wise subtract |
| 28 | + inline f32x4 v_sub(f32x4 a, f32x4 b) { return _mm_sub_ps(a, b); } |
| 29 | + |
| 30 | + // Horizontal sum of all 4 elements (for dot product, length, etc.) |
| 31 | + inline float v_hadd4(f32x4 a) |
| 32 | + { |
| 33 | + __m128 t1 = _mm_hadd_ps(a, a); // sums pairs: [a0+a1, a2+a3, ...] |
| 34 | + __m128 t2 = _mm_hadd_ps(t1, t1); // sums again: first element = a0+a1+a2+a3 |
| 35 | + return _mm_cvtss_f32(t2); // extract first element |
| 36 | + } |
| 37 | + |
| 38 | + // specialized dot product for SSE4.1 |
| 39 | + float float4_dot_sse41(const float* a, const float* b) |
| 40 | + { |
| 41 | + f32x4 va = _mm_loadu_ps(a); |
| 42 | + f32x4 vb = _mm_loadu_ps(b); |
| 43 | + __m128 dp = _mm_dp_ps(va, vb, 0xF1); // multiply all 4, sum all 4, lowest lane |
| 44 | + return _mm_cvtss_f32(dp); |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +#include "float4_impl.inl" |
| 49 | + |
| 50 | +namespace math_backend::float4::dispatch |
| 51 | +{ |
| 52 | + // Install SSE41 backend |
| 53 | + void install_sse41() |
| 54 | + { |
| 55 | + gFloat4.add = float4_add_impl; |
| 56 | + gFloat4.sub = float4_sub_impl; |
| 57 | + gFloat4.mul = float4_mul_impl; |
| 58 | + gFloat4.mul_scalar = float4_mul_scalar_impl; |
| 59 | + gFloat4.div = float4_div_impl; |
| 60 | + gFloat4.div_scalar = float4_div_scalar_impl; |
| 61 | + gFloat4.dot = float4_dot_sse41; |
| 62 | + gFloat4.length = float4_length_impl; |
| 63 | + gFloat4.lengthSquared = float4_length_squared_impl; |
| 64 | + gFloat4.normalize = float4_normalize_impl; |
| 65 | + gFloat4.lerp = float4_lerp_impl; |
| 66 | + } |
| 67 | +} |
0 commit comments