Skip to content

Commit cfc45e2

Browse files
committed
inline, unroll, VMLA
1 parent 1b6156d commit cfc45e2

1 file changed

Lines changed: 18 additions & 24 deletions

File tree

pybricks/experimental/pb_module_experimental.c

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,30 @@
1010
#include "py/runtime.h"
1111
#include <math.h>
1212

13-
// High-Precision Constants
1413
static const float PI_F = 3.141592653589793f;
1514
static const float TWO_PI_F = 6.283185307179586f;
1615
static const float HALF_PI_F = 1.570796326794896f;
1716
static const float INV_TWO_PI_F = 0.159154943091895f;
1817

1918
// -----------------------------------------------------------------------------
20-
// Core Math Engines (Hardware Accelerated with VMLA)
19+
// Core Math Engines (Optimized for Auto-VMLA generation)
2120
// -----------------------------------------------------------------------------
2221

2322
static inline float fast_sin_internal(float theta) {
24-
// 1. Fast Range Reduction
2523
float x = theta * INV_TWO_PI_F;
2624
x = theta - (float)((int)(x + (x > 0 ? 0.5f : -0.5f))) * TWO_PI_F;
2725

28-
// 2. Symmetry Folding
2926
if (x > HALF_PI_F) { x = PI_F - x; }
3027
else if (x < -HALF_PI_F) { x = -PI_F - x; }
3128

3229
float x2 = x * x;
3330

34-
/* * 3. 7th-Degree Polynomial using ARM VMLA Intrinsics
35-
* Horner's Method: y = x * (1 + x^2 * (C1 + x^2 * (C2 + x^2 * C3)))
36-
* The compiler maps these to single-cycle hardware instructions.
37-
*/
31+
// By writing it as a = a + b * c, the GCC compiler for Cortex-M4
32+
// will automatically generate the VMLA (Multiply-Accumulate) instruction.
3833
float res = -0.000195152f;
39-
res = __builtin_arm_vmla_f32(0.008332152f, x2, res);
40-
res = __builtin_arm_vmla_f32(-0.166666567f, x2, res);
41-
res = __builtin_arm_vmla_f32(1.0f, x2, res);
34+
res = 0.008332152f + (x2 * res);
35+
res = -0.166666567f + (x2 * res);
36+
res = 1.0f + (x2 * res);
4237

4338
return x * res;
4439
}
@@ -52,20 +47,22 @@ static inline float fast_atan2_internal(float y, float x) {
5247

5348
if (abs_x >= abs_y) {
5449
float r = y / x;
55-
// Optimization: Use VMLA for the rational denominator
56-
angle = r * (1.0f / __builtin_arm_vmla_f32(1.0f, r * r, 0.28086f));
50+
// Simplified for auto-VMLA
51+
float den = 1.0f + (r * r * 0.28086f);
52+
angle = r * (1.0f / den);
5753
if (x < 0.0f) {
5854
angle += (y >= 0.0f) ? PI_F : -PI_F;
5955
}
6056
} else {
6157
float r = x / y;
62-
angle = (y > 0.0f ? HALF_PI_F : -HALF_PI_F) - r * (1.0f / __builtin_arm_vmla_f32(1.0f, r * r, 0.28086f));
58+
float den = 1.0f + (r * r * 0.28086f);
59+
angle = (y > 0.0f ? HALF_PI_F : -HALF_PI_F) - r * (1.0f / den);
6360
}
6461
return angle;
6562
}
6663

6764
// -----------------------------------------------------------------------------
68-
// Unrolled Benchmark (Slashing Loop Overhead)
65+
// Unrolled Benchmark
6966
// -----------------------------------------------------------------------------
7067

7168
static mp_obj_t experimental_benchmark_detailed(mp_obj_t n_in) {
@@ -74,8 +71,8 @@ static mp_obj_t experimental_benchmark_detailed(mp_obj_t n_in) {
7471
uint32_t t0, t1, t2, t3;
7572
float inv_n = 1.0f / (float)n;
7673

77-
// Benchmark Sin with 4x Loop Unrolling
7874
t0 = mp_hal_ticks_ms();
75+
// 4x Unrolling for Sin
7976
for (int32_t i = 0; i < n; i += 4) {
8077
result += fast_sin_internal((float)(i) * inv_n);
8178
result += fast_sin_internal((float)(i+1) * inv_n);
@@ -84,15 +81,12 @@ static mp_obj_t experimental_benchmark_detailed(mp_obj_t n_in) {
8481
}
8582

8683
t1 = mp_hal_ticks_ms();
84+
// 4x Unrolling for Cos
8785
for (int32_t i = 0; i < n; i += 4) {
88-
float v0 = ((float)(i) * inv_n) + HALF_PI_F;
89-
float v1 = ((float)(i+1) * inv_n) + HALF_PI_F;
90-
float v2 = ((float)(i+2) * inv_n) + HALF_PI_F;
91-
float v3 = ((float)(i+3) * inv_n) + HALF_PI_F;
92-
result += fast_sin_internal(v0);
93-
result += fast_sin_internal(v1);
94-
result += fast_sin_internal(v2);
95-
result += fast_sin_internal(v3);
86+
result += fast_sin_internal(((float)(i) * inv_n) + HALF_PI_F);
87+
result += fast_sin_internal(((float)(i+1) * inv_n) + HALF_PI_F);
88+
result += fast_sin_internal(((float)(i+2) * inv_n) + HALF_PI_F);
89+
result += fast_sin_internal(((float)(i+3) * inv_n) + HALF_PI_F);
9690
}
9791

9892
t2 = mp_hal_ticks_ms();

0 commit comments

Comments
 (0)