Skip to content

Commit 650a381

Browse files
committed
NEON SIMD for ARM
1 parent ec68bd7 commit 650a381

2 files changed

Lines changed: 88 additions & 4 deletions

File tree

CMakeLists.txt

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,19 @@ endif()
291291
#
292292

293293
# Option to disable SIMD entirely
294-
option(USE_SIMD "Enable SIMD optimizations (SSE4.2/AVX2)" ON)
294+
option(USE_SIMD "Enable SIMD optimizations (SSE4.2/AVX2 on x86_64, NEON on ARM64)" ON)
295+
296+
# Check architecture
297+
# CMAKE_SYSTEM_PROCESSOR is "x86_64" on Intel Macs and Linux x86_64 ("AMD64" on Windows), and "arm64"/"aarch64" ("ARM64" on Windows) on 64-bit ARM
298+
set(IS_X86_64 FALSE)
299+
set(IS_ARM64 FALSE)
300+
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64|amd64|i686|i386")
301+
set(IS_X86_64 TRUE)
302+
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64|aarch64|ARM64")
303+
set(IS_ARM64 TRUE)
304+
endif()
295305

296-
if(USE_SIMD AND NOT WIN32)
306+
if(USE_SIMD AND NOT WIN32 AND IS_X86_64)
297307
include(CheckCXXCompilerFlag)
298308

299309
# Check for AVX2 support
@@ -315,8 +325,14 @@ if(USE_SIMD AND NOT WIN32)
315325
add_compile_definitions(EIDOS_HAS_SSE42=1)
316326
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2")
317327
else()
318-
message(STATUS "SIMD: No SIMD support detected, using scalar fallback")
328+
message(STATUS "SIMD: No x86 SIMD support detected, using scalar fallback")
319329
endif()
330+
elseif(USE_SIMD AND NOT WIN32 AND IS_ARM64)
331+
# NEON is part of the baseline ARM64 (AArch64) target, so no extra compiler flag is needed
332+
message(STATUS "SIMD: ARM64 NEON support enabled")
333+
add_compile_definitions(EIDOS_HAS_NEON=1)
334+
elseif(USE_SIMD AND NOT WIN32)
335+
message(STATUS "SIMD: Unknown architecture (${CMAKE_SYSTEM_PROCESSOR}), using scalar fallback")
320336
elseif(USE_SIMD AND WIN32)
321337
# Windows/MSVC detection not yet implemented
322338
message(STATUS "SIMD: Windows SIMD detection not yet implemented, using scalar fallback")

eidos/eidos_simd.h

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
SIMD acceleration for Eidos math operations, independent of OpenMP.
2323
2424
This header provides vectorized implementations of common math operations
25-
using SSE4.2 or AVX2 intrinsics when available, with scalar fallbacks.
25+
using platform-specific SIMD intrinsics when available:
26+
- x86_64: SSE4.2 or AVX2 via <immintrin.h>
27+
- ARM64: NEON via <arm_neon.h>
28+
Falls back to scalar code when no SIMD is available.
2629
2730
*/
2831

@@ -42,6 +45,10 @@
4245
#include <smmintrin.h>
4346
#define EIDOS_SIMD_WIDTH 2 // 2 doubles per SSE register
4447
#define EIDOS_SIMD_FLOAT_WIDTH 4 // 4 floats per SSE register
48+
#elif defined(EIDOS_HAS_NEON)
49+
#include <arm_neon.h>
50+
#define EIDOS_SIMD_WIDTH 2 // 2 doubles per NEON register
51+
#define EIDOS_SIMD_FLOAT_WIDTH 4 // 4 floats per NEON register
4552
#else
4653
#define EIDOS_SIMD_WIDTH 1 // Scalar fallback
4754
#define EIDOS_SIMD_FLOAT_WIDTH 1
@@ -78,6 +85,14 @@ inline void sqrt_float64(const double *input, double *output, int64_t count)
7885
__m128d r = _mm_sqrt_pd(v);
7986
_mm_storeu_pd(&output[i], r);
8087
}
88+
#elif defined(EIDOS_HAS_NEON)
89+
// Process 2 doubles at a time
90+
for (; i + 2 <= count; i += 2)
91+
{
92+
float64x2_t v = vld1q_f64(&input[i]);
93+
float64x2_t r = vsqrtq_f64(v);
94+
vst1q_f64(&output[i], r);
95+
}
8196
#endif
8297

8398
// Scalar remainder
@@ -109,6 +124,13 @@ inline void abs_float64(const double *input, double *output, int64_t count)
109124
__m128d r = _mm_andnot_pd(sign_mask, v);
110125
_mm_storeu_pd(&output[i], r);
111126
}
127+
#elif defined(EIDOS_HAS_NEON)
128+
for (; i + 2 <= count; i += 2)
129+
{
130+
float64x2_t v = vld1q_f64(&input[i]);
131+
float64x2_t r = vabsq_f64(v);
132+
vst1q_f64(&output[i], r);
133+
}
112134
#endif
113135

114136
for (; i < count; i++)
@@ -136,6 +158,13 @@ inline void floor_float64(const double *input, double *output, int64_t count)
136158
__m128d r = _mm_floor_pd(v);
137159
_mm_storeu_pd(&output[i], r);
138160
}
161+
#elif defined(EIDOS_HAS_NEON)
162+
for (; i + 2 <= count; i += 2)
163+
{
164+
float64x2_t v = vld1q_f64(&input[i]);
165+
float64x2_t r = vrndmq_f64(v); // Round toward minus infinity (floor)
166+
vst1q_f64(&output[i], r);
167+
}
139168
#endif
140169

141170
for (; i < count; i++)
@@ -163,6 +192,13 @@ inline void ceil_float64(const double *input, double *output, int64_t count)
163192
__m128d r = _mm_ceil_pd(v);
164193
_mm_storeu_pd(&output[i], r);
165194
}
195+
#elif defined(EIDOS_HAS_NEON)
196+
for (; i + 2 <= count; i += 2)
197+
{
198+
float64x2_t v = vld1q_f64(&input[i]);
199+
float64x2_t r = vrndpq_f64(v); // Round toward plus infinity (ceil)
200+
vst1q_f64(&output[i], r);
201+
}
166202
#endif
167203

168204
for (; i < count; i++)
@@ -190,6 +226,13 @@ inline void trunc_float64(const double *input, double *output, int64_t count)
190226
__m128d r = _mm_round_pd(v, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
191227
_mm_storeu_pd(&output[i], r);
192228
}
229+
#elif defined(EIDOS_HAS_NEON)
230+
for (; i + 2 <= count; i += 2)
231+
{
232+
float64x2_t v = vld1q_f64(&input[i]);
233+
float64x2_t r = vrndq_f64(v); // Round toward zero (truncate)
234+
vst1q_f64(&output[i], r);
235+
}
193236
#endif
194237

195238
for (; i < count; i++)
@@ -217,6 +260,13 @@ inline void round_float64(const double *input, double *output, int64_t count)
217260
__m128d r = _mm_round_pd(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
218261
_mm_storeu_pd(&output[i], r);
219262
}
263+
#elif defined(EIDOS_HAS_NEON)
264+
for (; i + 2 <= count; i += 2)
265+
{
266+
float64x2_t v = vld1q_f64(&input[i]);
267+
float64x2_t r = vrndaq_f64(v); // Round to nearest, ties away from zero
268+
vst1q_f64(&output[i], r);
269+
}
220270
#endif
221271

222272
for (; i < count; i++)
@@ -298,6 +348,15 @@ inline double sum_float64(const double *input, int64_t count)
298348
__m128d shuf = _mm_shuffle_pd(vsum, vsum, 1);
299349
vsum = _mm_add_sd(vsum, shuf);
300350
sum = _mm_cvtsd_f64(vsum);
351+
#elif defined(EIDOS_HAS_NEON)
352+
float64x2_t vsum = vdupq_n_f64(0.0);
353+
for (; i + 2 <= count; i += 2)
354+
{
355+
float64x2_t v = vld1q_f64(&input[i]);
356+
vsum = vaddq_f64(vsum, v);
357+
}
358+
// Horizontal sum of 2 doubles
359+
sum = vaddvq_f64(vsum);
301360
#endif
302361

303362
// Scalar remainder
@@ -339,6 +398,15 @@ inline double product_float64(const double *input, int64_t count)
339398
__m128d shuf = _mm_shuffle_pd(vprod, vprod, 1);
340399
vprod = _mm_mul_sd(vprod, shuf);
341400
prod = _mm_cvtsd_f64(vprod);
401+
#elif defined(EIDOS_HAS_NEON)
402+
float64x2_t vprod = vdupq_n_f64(1.0);
403+
for (; i + 2 <= count; i += 2)
404+
{
405+
float64x2_t v = vld1q_f64(&input[i]);
406+
vprod = vmulq_f64(vprod, v);
407+
}
408+
// Horizontal product of 2 doubles
409+
prod = vgetq_lane_f64(vprod, 0) * vgetq_lane_f64(vprod, 1);
342410
#endif
343411

344412
for (; i < count; i++)

0 commit comments

Comments
 (0)