@@ -56,6 +56,84 @@ HWY_INLINE void ReduceAngleTan(D d, V ang, V& x_red, V& sign) {
5656
5757} // namespace impl
5858
59+ namespace impl {
60+
61+ template <class T >
62+ struct FastExpImpl {};
63+
64+ template <>
65+ struct FastExpImpl <float > {
66+ // Rounds float toward zero and returns as int32_t.
67+ template <class D , class V >
68+ HWY_INLINE Vec<Rebind<int32_t , D>> ToInt32 (D /* unused*/ , V x) {
69+ return ConvertInRangeTo (Rebind<int32_t , D>(), x);
70+ }
71+
72+ // Computes 2^x, where x is an integer.
73+ template <class D , class VI32 >
74+ HWY_INLINE Vec<D> Pow2I (D d, VI32 x) {
75+ const Rebind<int32_t , D> di32;
76+ const VI32 kOffset = Set (di32, 0x7F );
77+ return BitCast (d, ShiftLeft<23 >(Add (x, kOffset )));
78+ }
79+
80+ // Sets the exponent of 'x' to 2^e.
81+ template <class D , class V , class VI32 >
82+ HWY_INLINE V LoadExpShortRange (D d, V x, VI32 e) {
83+ const VI32 y = ShiftRight<1 >(e);
84+ return Mul (Mul (x, Pow2I (d, y)), Pow2I (d, Sub (e, y)));
85+ }
86+
87+ template <class D , class V , class VI32 >
88+ HWY_INLINE V ExpReduce (D d, V x, VI32 q) {
89+ // kMinusLn2 ~= -ln(2)
90+ const V kMinusLn2 = Set (d, -0 .69314718056f );
91+
92+ // Extended precision modular arithmetic.
93+ const V qf = ConvertTo (d, q);
94+ return MulAdd (qf, kMinusLn2 , x);
95+ }
96+ };
97+
98+ #if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64
99+ template <>
100+ struct FastExpImpl <double > {
101+ // Rounds double toward zero and returns as int32_t.
102+ template <class D , class V >
103+ HWY_INLINE Vec<Rebind<int32_t , D>> ToInt32 (D /* unused*/ , V x) {
104+ return DemoteInRangeTo (Rebind<int32_t , D>(), x);
105+ }
106+
107+ // Computes 2^x, where x is an integer.
108+ template <class D , class VI32 >
109+ HWY_INLINE Vec<D> Pow2I (D d, VI32 x) {
110+ const Rebind<int32_t , D> di32;
111+ const Rebind<int64_t , D> di64;
112+ const VI32 kOffset = Set (di32, 0x3FF );
113+ return BitCast (d, ShiftLeft<52 >(PromoteTo (di64, Add (x, kOffset ))));
114+ }
115+
116+ // Sets the exponent of 'x' to 2^e.
117+ template <class D , class V , class VI32 >
118+ HWY_INLINE V LoadExpShortRange (D d, V x, VI32 e) {
119+ const VI32 y = ShiftRight<1 >(e);
120+ return Mul (Mul (x, Pow2I (d, y)), Pow2I (d, Sub (e, y)));
121+ }
122+
123+ template <class D , class V , class VI32 >
124+ HWY_INLINE V ExpReduce (D d, V x, VI32 q) {
125+ // kMinusLn2 ~= -ln(2)
126+ const V kMinusLn2 = Set (d, -0.6931471805599453 );
127+
128+ // Extended precision modular arithmetic.
129+ const V qf = PromoteTo (d, q);
130+ return MulAdd (qf, kMinusLn2 , x);
131+ }
132+ };
133+ #endif
134+
135+ } // namespace impl
136+
59137/* *
60138 * Fast approximation of tan(x).
61139 *
@@ -778,6 +856,69 @@ HWY_INLINE V FastLog(D d, V x) {
778856 return MulAdd (exp, kLn2 , approx);
779857}
780858
859+ /* *
860+ * Fast approximation of exp(x).
861+ *
862+ * Valid Lane Types: float32, float64
863+ * Max ULP Error: 1 for float32 [-FLT_MAX, -87]
864+ * Max ULP Error: 1 for float64 [-DBL_MAX, -708]
865+ * Max Relative Error: 0.06% for float32 [-87, 88]
866+ * Max Relative Error: 0.06% for float64 [-708, 706]
867+ * Average Relative Error: 0.005% for float32 [-87, 88]
868+ * Average Relative Error: 0.006% for float64 [-708, 706]
869+ * Valid Range: float32[-FLT_MAX, +88], float64[-DBL_MAX, +706]
870+ *
871+ * @return e^x
872+ */
873+ template <class D , class V >
874+ HWY_INLINE V FastExp (D d, V x) {
875+ using T = TFromD<D>;
876+ impl::FastExpImpl<T> impl;
877+
878+ const V kHalf = Set (d, static_cast <T>(+0.5 ));
879+ const V kLowerBound =
880+ Set (d, static_cast <T>((sizeof (T) == 4 ? -104.0 : -1000.0 )));
881+ const V kNegZero = Set (d, static_cast <T>(-0.0 ));
882+
883+ const V kOneOverLog2 = Set (d, static_cast <T>(+1.442695040888963407359924681 ));
884+
885+ using TI = MakeSigned<T>;
886+ const Rebind<TI, D> di;
887+ const auto rounded_offs = BitCast (
888+ d, OrAnd (BitCast (di, kHalf ), BitCast (di, x), BitCast (di, kNegZero )));
889+
890+ const auto q = impl.ToInt32 (d, MulAdd (x, kOneOverLog2 , rounded_offs));
891+
892+ // Reduce
893+ const auto x_red = impl.ExpReduce (d, x, q);
894+
895+ // New logic:
896+ // x_in = |x_red| / 2 -> absorbed into coefficients
897+ // if x_red < 0: swap num/den
898+
899+ auto y = Abs (x_red);
900+
901+ const auto a = Set (d, static_cast <T>(-1757.05 ));
902+ const auto b = Set (d, static_cast <T>(-3128.2 ));
903+ const auto c = Set (d, static_cast <T>(1406.95 ));
904+ const auto d_coef = Set (d, static_cast <T>(-3130.2 ));
905+
906+ // res = (Ay + B) / (Cy + D)
907+ auto num = MulAdd (a, y, b);
908+ auto den = MulAdd (c, y, d_coef);
909+
910+ // If x_red < 0, swap num/den
911+ auto final_num = IfNegativeThenElse (x_red, den, num);
912+ auto final_den = IfNegativeThenElse (x_red, num, den);
913+
914+ auto approx = Div (final_num, final_den);
915+
916+ const V res = impl.LoadExpShortRange (d, approx, q);
917+
918+ // Handle underflow
919+ return IfThenElseZero (Ge (x, kLowerBound ), res);
920+ }
921+
781922template <class D , class V >
782923HWY_NOINLINE V CallFastAtan (const D d, VecArg<V> x) {
783924 return FastAtan (d, x);
@@ -803,6 +944,11 @@ HWY_NOINLINE V CallFastLog(const D d, VecArg<V> x) {
803944 return FastLog (d, x);
804945}
805946
947+ template <class D , class V >
948+ HWY_NOINLINE V CallFastExp (const D d, VecArg<V> x) {
949+ return FastExp (d, x);
950+ }
951+
806952} // namespace HWY_NAMESPACE
807953} // namespace hwy
808954HWY_AFTER_NAMESPACE ();
0 commit comments