@@ -190,4 +190,58 @@ Simd<T, N> erfinv(Simd<T, N> a_) {
190190 }
191191}
192192
193+ /* *
194+ * Modified Bessel function of the first kind, order zero: I0(x).
195+ * Cephes polynomial approximation in two domains:
196+ * |x| <= 3.75 → polynomial in (x/3.75)^2
197+ * |x| > 3.75 → exp(|x|) / sqrt(|x|) * polynomial in (3.75/|x|)
198+ */
199+ template <typename T, int N>
200+ Simd<T, N> i0 (Simd<T, N> x_) {
201+ Simd<float , N> x = x_;
202+ Simd<float , N> y = abs (x);
203+
204+ // Branch 1: y <= 3.75
205+ auto small = [](Simd<float , N> y) {
206+ Simd<float , N> t = y / 3 .75f ;
207+ t = t * t;
208+ Simd<float , N> p (1 .0f );
209+ p = fma (t, Simd<float , N>(3 .5156229f ), p);
210+ // Horner evaluation of the inner polynomial
211+ Simd<float , N> r (0 .0045813f );
212+ r = fma (r, t, Simd<float , N>(0 .0360768f ));
213+ r = fma (r, t, Simd<float , N>(0 .2659732f ));
214+ r = fma (r, t, Simd<float , N>(1 .2067492f ));
215+ r = fma (r, t, Simd<float , N>(3 .0899424f ));
216+ r = fma (r, t, Simd<float , N>(3 .5156229f ));
217+ r = fma (r, t, Simd<float , N>(1 .0f ));
218+ return r;
219+ };
220+
221+ // Branch 2: y > 3.75
222+ auto large = [](Simd<float , N> y) {
223+ Simd<float , N> t = Simd<float , N>(3 .75f ) / y;
224+ Simd<float , N> p (0 .00392377f );
225+ p = fma (p, t, Simd<float , N>(-0 .01647633f ));
226+ p = fma (p, t, Simd<float , N>(0 .02635537f ));
227+ p = fma (p, t, Simd<float , N>(-0 .02057706f ));
228+ p = fma (p, t, Simd<float , N>(0 .00916281f ));
229+ p = fma (p, t, Simd<float , N>(-0 .00157565f ));
230+ p = fma (p, t, Simd<float , N>(0 .00225319f ));
231+ p = fma (p, t, Simd<float , N>(0 .01328592f ));
232+ p = fma (p, t, Simd<float , N>(0 .39894228f ));
233+ return (exp (y) / sqrt (y)) * p;
234+ };
235+
236+ if constexpr (N == 1 ) {
237+ if ((y <= 3 .75f ).value ) {
238+ return Simd<T, N>(small (y));
239+ } else {
240+ return Simd<T, N>(large (y));
241+ }
242+ } else {
243+ return Simd<T, N>(select (y <= 3 .75f , small (y), large (y)));
244+ }
245+ }
246+
193247} // namespace mlx::core::simd
0 commit comments