Skip to content

Commit 1965f8f

Browse files
committed
feat: native i0 (modified Bessel function) and kaiser window
1 parent 8cd377b commit 1965f8f

File tree

15 files changed

+414
-0
lines changed

15 files changed

+414
-0
lines changed

mlx/backend/cpu/simd/math.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,58 @@ Simd<T, N> erfinv(Simd<T, N> a_) {
190190
}
191191
}
192192

193+
/**
194+
* Modified Bessel function of the first kind, order zero: I0(x).
195+
* Cephes polynomial approximation in two domains:
196+
* |x| <= 3.75 → polynomial in (x/3.75)^2
197+
* |x| > 3.75 → exp(|x|) / sqrt(|x|) * polynomial in (3.75/|x|)
198+
*/
199+
template <typename T, int N>
200+
Simd<T, N> i0(Simd<T, N> x_) {
201+
Simd<float, N> x = x_;
202+
Simd<float, N> y = abs(x);
203+
204+
// Branch 1: y <= 3.75
205+
auto small = [](Simd<float, N> y) {
206+
Simd<float, N> t = y / 3.75f;
207+
t = t * t;
208+
Simd<float, N> p(1.0f);
209+
p = fma(t, Simd<float, N>(3.5156229f), p);
210+
// Horner evaluation of the inner polynomial
211+
Simd<float, N> r(0.0045813f);
212+
r = fma(r, t, Simd<float, N>(0.0360768f));
213+
r = fma(r, t, Simd<float, N>(0.2659732f));
214+
r = fma(r, t, Simd<float, N>(1.2067492f));
215+
r = fma(r, t, Simd<float, N>(3.0899424f));
216+
r = fma(r, t, Simd<float, N>(3.5156229f));
217+
r = fma(r, t, Simd<float, N>(1.0f));
218+
return r;
219+
};
220+
221+
// Branch 2: y > 3.75
222+
auto large = [](Simd<float, N> y) {
223+
Simd<float, N> t = Simd<float, N>(3.75f) / y;
224+
Simd<float, N> p(0.00392377f);
225+
p = fma(p, t, Simd<float, N>(-0.01647633f));
226+
p = fma(p, t, Simd<float, N>(0.02635537f));
227+
p = fma(p, t, Simd<float, N>(-0.02057706f));
228+
p = fma(p, t, Simd<float, N>(0.00916281f));
229+
p = fma(p, t, Simd<float, N>(-0.00157565f));
230+
p = fma(p, t, Simd<float, N>(0.00225319f));
231+
p = fma(p, t, Simd<float, N>(0.01328592f));
232+
p = fma(p, t, Simd<float, N>(0.39894228f));
233+
return (exp(y) / sqrt(y)) * p;
234+
};
235+
236+
if constexpr (N == 1) {
237+
if ((y <= 3.75f).value) {
238+
return Simd<T, N>(small(y));
239+
} else {
240+
return Simd<T, N>(large(y));
241+
}
242+
} else {
243+
return Simd<T, N>(select(y <= 3.75f, small(y), large(y)));
244+
}
245+
}
246+
193247
} // namespace mlx::core::simd

mlx/backend/cpu/unary.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
103103
unary_real_fp(in, out, detail::ErfInv(), stream());
104104
}
105105

106+
void I0::eval_cpu(const std::vector<array>& inputs, array& out) {
107+
assert(inputs.size() == 1);
108+
const auto& in = inputs[0];
109+
unary_real_fp(in, out, detail::I0(), stream());
110+
}
111+
106112
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
107113
assert(inputs.size() == 1);
108114
const auto& in = inputs[0];

mlx/backend/cpu/unary_ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ DEFAULT_OP(ErfInv, erfinv)
4444
DEFAULT_OP(Exp, exp)
4545
DEFAULT_OP(Expm1, expm1)
4646
DEFAULT_OP(Floor, floor);
47+
DEFAULT_OP(I0, i0)
4748
DEFAULT_OP(Log, log);
4849
DEFAULT_OP(Log2, log2);
4950
DEFAULT_OP(Log10, log10);

mlx/backend/metal/kernels/i0.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#pragma once
4+
#include <metal_math>
5+
6+
/*
 * Modified Bessel function of the first kind, order zero: I0(x).
 * Uses the Abramowitz & Stegun polynomial approximations in two domains.
 *
 * Domain 1: |x| <= 3.75 -> polynomial in (x/3.75)^2          (A&S 9.8.1)
 * Domain 2: |x| >  3.75 -> exp(|x|) / sqrt(|x|) *
 *                          polynomial in (3.75/|x|)          (A&S 9.8.2)
 *
 * Both polynomials are evaluated as fused multiply-add Horner chains,
 * highest-order coefficient first. Only |x| is used, so the result is the
 * same for x and -x.
 *
 * Reference: Abramowitz & Stegun, Handbook of Mathematical Functions,
 * section 9.8 (polynomial approximations).
 */
float i0_impl(float x) {
  float y = metal::abs(x);

  if (y <= 3.75f) {
    float t = y / 3.75f;
    t = t * t;
    float r = 0.0045813f;
    r = metal::fma(r, t, 0.0360768f);
    r = metal::fma(r, t, 0.2659732f);
    r = metal::fma(r, t, 1.2067492f);
    r = metal::fma(r, t, 3.0899424f);
    r = metal::fma(r, t, 3.5156229f);
    r = metal::fma(r, t, 1.0f);
    return r;
  } else {
    float t = 3.75f / y;
    float p = 0.00392377f;
    p = metal::fma(p, t, -0.01647633f);
    p = metal::fma(p, t, 0.02635537f);
    p = metal::fma(p, t, -0.02057706f);
    p = metal::fma(p, t, 0.00916281f);
    p = metal::fma(p, t, -0.00157565f);
    p = metal::fma(p, t, 0.00225319f);
    p = metal::fma(p, t, 0.01328592f);
    p = metal::fma(p, t, 0.39894228f);
    // NOTE(review): exp(y) is evaluated directly, so the result presumably
    // becomes inf once exp(y) overflows float -- confirm whether callers care.
    return (metal::precise::exp(y) / metal::precise::sqrt(y)) * p;
  }
}

mlx/backend/metal/kernels/unary.metal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ instantiate_unary_types(Negative)
6767
instantiate_unary_float(Sigmoid)
6868
instantiate_unary_float(Erf)
6969
instantiate_unary_float(ErfInv)
70+
instantiate_unary_float(I0)
7071
instantiate_unary_types(Sign)
7172
instantiate_unary_float(Sin)
7273
instantiate_unary_float(Sinh)

mlx/backend/metal/kernels/unary_ops.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlx/backend/metal/kernels/erf.h"
1010
#include "mlx/backend/metal/kernels/expm1f.h"
1111
#include "mlx/backend/metal/kernels/fp8.h"
12+
#include "mlx/backend/metal/kernels/i0.h"
1213

1314
namespace {
1415
constant float inf = metal::numeric_limits<float>::infinity();
@@ -174,6 +175,13 @@ struct ErfInv {
174175
};
175176
};
176177

178+
// Unary functor for I0: converts the input to float, evaluates i0_impl, and
// converts the result back to the input type T.
struct I0 {
  template <typename T>
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    const float result = i0_impl(widened);
    return static_cast<T>(result);
  };
};
184+
177185
struct Exp {
178186
template <typename T>
179187
T operator()(T x) {

mlx/backend/metal/unary.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ UNARY_GPU(Cos)
125125
UNARY_GPU(Cosh)
126126
UNARY_GPU(Erf)
127127
UNARY_GPU(ErfInv)
128+
UNARY_GPU(I0)
128129
UNARY_GPU(Exp)
129130
UNARY_GPU(Expm1)
130131
UNARY_GPU(Imag)

mlx/backend/no_cpu/primitives.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ NO_CPU_MULTI(Eigh)
6060
NO_CPU(Equal)
6161
NO_CPU(Erf)
6262
NO_CPU(ErfInv)
63+
NO_CPU(I0)
6364
NO_CPU(Exp)
6465
NO_CPU(ExpandDims)
6566
NO_CPU(Expm1)

mlx/backend/no_gpu/primitives.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ NO_GPU(Remainder)
8787
NO_GPU(Equal)
8888
NO_GPU(Erf)
8989
NO_GPU(ErfInv)
90+
NO_GPU(I0)
9091
NO_GPU(Exp)
9192
NO_GPU(ExpandDims)
9293
NO_GPU(Expm1)

mlx/ops.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3011,6 +3011,35 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) {
30113011
{astype(a, dtype, s)});
30123012
}
30133013

3014+
// Element-wise I0 of `a`. The input is cast to at_least_float(a.dtype()) and
// the output has that dtype and the shape of `a`.
array i0(const array& a, StreamOrDevice s /* = {} */) {
  const auto out_type = at_least_float(a.dtype());
  auto casted = astype(a, out_type, s);
  auto primitive = std::make_shared<I0>(to_stream(s));
  return array(a.shape(), out_type, primitive, {casted});
}
3022+
3023+
// Kaiser window with M points and shape parameter beta, as a float32 array:
//
//   w(n) = I0(beta * sqrt(1 - ((2n/(M-1)) - 1)^2)) / I0(beta),  n = 0..M-1
//
// M < 1 returns an empty array; M == 1 returns a single 1.
array kaiser(int M, float beta, StreamOrDevice s /* = {} */) {
  if (M <= 1) {
    if (M == 1) {
      return ones({1}, float32, s);
    }
    return array({});
  }

  // Scalar float32 constant.
  auto scalar = [](float value) { return array(value, float32); };

  // Sample positions mapped to [-1, 1]: (n - (M-1)/2) / ((M-1)/2).
  auto positions = arange(0, M, float32, s);
  auto half_span = scalar((M - 1) / 2.0f);
  auto centered = subtract(positions, half_span, s);
  auto normalized = divide(centered, half_span, s);

  // beta * sqrt(1 - normalized^2)
  auto radicand = subtract(scalar(1.0f), square(normalized, s), s);
  auto scaled = multiply(scalar(beta), sqrt(radicand, s), s);

  // Normalize by I0(beta).
  auto norm = i0(scalar(beta), s);
  return divide(i0(scaled, s), norm, s);
}
3042+
30143043
array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
30153044
return array(
30163045
a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});

0 commit comments

Comments
 (0)