Skip to content

Commit dfc62fb

Browse files
Nikhil0250 authored and copybara-github committed
Add FastGelu activation function in a newly created fast_ops-inl.h file.
This replaces the Tanh call with a FastTanh call in the Gelu function written in math-inl.h. PiperOrigin-RevId: 874783036
1 parent da7098a commit dfc62fb

2 files changed

Lines changed: 189 additions & 3 deletions

File tree

hwy/contrib/math/fast_math-inl.h

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,84 @@ HWY_INLINE void ReduceAngleTan(D d, V ang, V& x_red, V& sign) {
5656

5757
} // namespace impl
5858

59+
namespace impl {
60+
61+
// Primary template, intentionally empty. The helpers that FastExp calls
// (ToInt32, Pow2I, LoadExpShortRange, ExpReduce) are provided only by the
// explicit specializations for float and double.
template <class T>
62+
struct FastExpImpl {};
63+
64+
// Helpers used by FastExp for float (32-bit) lanes.
template <>
65+
struct FastExpImpl<float> {
66+
// Rounds float toward zero and returns as int32_t.
// NOTE(review): uses ConvertInRangeTo, which presumably requires x to be
// representable as int32_t - confirm against the op's documentation.
67+
template <class D, class V>
68+
HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
69+
return ConvertInRangeTo(Rebind<int32_t, D>(), x);
70+
}
71+
72+
// Computes 2^x, where x is an integer.
// Adds the offset 0x7F to x and shifts the sum left by 23 bits, then
// reinterprets the resulting int32 bits as float.
// NOTE(review): presumably only meaningful while x + 0x7F fits the 8-bit
// exponent field - confirm the expected range of x with callers.
73+
template <class D, class VI32>
74+
HWY_INLINE Vec<D> Pow2I(D d, VI32 x) {
75+
const Rebind<int32_t, D> di32;
76+
const VI32 kOffset = Set(di32, 0x7F);
77+
return BitCast(d, ShiftLeft<23>(Add(x, kOffset)));
78+
}
79+
80+
// Returns x * 2^e. The scaling is applied as two multiplications, by
// 2^(e >> 1) and by 2^(e - (e >> 1)), rather than by a single 2^e.
// NOTE(review): presumably so that each factor stays within Pow2I's usable
// range even when 2^e alone would not - confirm.
81+
template <class D, class V, class VI32>
82+
HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) {
83+
const VI32 y = ShiftRight<1>(e);
84+
return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y)));
85+
}
86+
// Range reduction: returns x - q * ln(2), with q converted to float.
87+
template <class D, class V, class VI32>
88+
HWY_INLINE V ExpReduce(D d, V x, VI32 q) {
89+
// kMinusLn2 ~= -ln(2)
90+
const V kMinusLn2 = Set(d, -0.69314718056f);
91+
92+
// Single-constant reduction: one MulAdd computes q * (-ln 2) + x. No
// split (hi/lo) ln(2) constant is used here.
93+
const V qf = ConvertTo(d, q);
94+
return MulAdd(qf, kMinusLn2, x);
95+
}
96+
};
97+
98+
// Helpers used by FastExp for double (64-bit) lanes. Only compiled when the
// target reports both float64 and 64-bit integer support.
#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64
99+
template <>
100+
struct FastExpImpl<double> {
101+
// Rounds double toward zero and returns as int32_t.
// NOTE(review): uses DemoteInRangeTo, which presumably requires x to be
// representable as int32_t - confirm against the op's documentation.
102+
template <class D, class V>
103+
HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
104+
return DemoteInRangeTo(Rebind<int32_t, D>(), x);
105+
}
106+
107+
// Computes 2^x, where x is an integer.
// Adds the offset 0x3FF to x in int32, promotes the sum to int64, shifts it
// left by 52 bits and reinterprets the resulting bits as double.
// NOTE(review): presumably only meaningful while x + 0x3FF fits the 11-bit
// exponent field - confirm the expected range of x with callers.
108+
template <class D, class VI32>
109+
HWY_INLINE Vec<D> Pow2I(D d, VI32 x) {
110+
const Rebind<int32_t, D> di32;
111+
const Rebind<int64_t, D> di64;
112+
const VI32 kOffset = Set(di32, 0x3FF);
113+
return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset))));
114+
}
115+
116+
// Returns x * 2^e. The scaling is applied as two multiplications, by
// 2^(e >> 1) and by 2^(e - (e >> 1)), rather than by a single 2^e.
// NOTE(review): presumably so that each factor stays within Pow2I's usable
// range even when 2^e alone would not - confirm.
117+
template <class D, class V, class VI32>
118+
HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) {
119+
const VI32 y = ShiftRight<1>(e);
120+
return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y)));
121+
}
122+
// Range reduction: returns x - q * ln(2), with q promoted to double.
123+
template <class D, class V, class VI32>
124+
HWY_INLINE V ExpReduce(D d, V x, VI32 q) {
125+
// kMinusLn2 ~= -ln(2)
126+
const V kMinusLn2 = Set(d, -0.6931471805599453);
127+
128+
// Single-constant reduction: one MulAdd computes q * (-ln 2) + x. No
// split (hi/lo) ln(2) constant is used here.
129+
const V qf = PromoteTo(d, q);
130+
return MulAdd(qf, kMinusLn2, x);
131+
}
132+
};
133+
#endif
134+
135+
} // namespace impl
136+
59137
/**
60138
* Fast approximation of tan(x).
61139
*
@@ -778,6 +856,69 @@ HWY_INLINE V FastLog(D d, V x) {
778856
return MulAdd(exp, kLn2, approx);
779857
}
780858

859+
/**
860+
* Fast approximation of exp(x).
861+
*
862+
* Valid Lane Types: float32, float64
863+
* Max ULP Error: 1 for float32 [-FLT_MAX, -87]
864+
* Max ULP Error: 1 for float64 [-DBL_MAX, -708]
865+
* Max Relative Error: 0.06% for float32 [-87, 88]
866+
* Max Relative Error: 0.06% for float64 [-708, 706]
867+
* Average Relative Error: 0.005% for float32 [-87, 88]
868+
* Average Relative Error: 0.006% for float64 [-708, 706]
869+
* Valid Range: float32[-FLT_MAX, +88], float64[-DBL_MAX, +706]
870+
*
* Method: q = x / ln(2) rounded to an integer, x_red = x - q * ln(2), then
* e^x_red is approximated by a degree-1 rational function of |x_red| and the
* result is scaled by 2^q. Lanes with x below -104 (float32) or -1000
* (float64) return zero.
*
* @param d Tag describing the float32/float64 vector type.
* @param x Input vector.
871+
* @return e^x
872+
*/
873+
template <class D, class V>
874+
HWY_INLINE V FastExp(D d, V x) {
875+
using T = TFromD<D>;
876+
impl::FastExpImpl<T> impl;
877+
878+
const V kHalf = Set(d, static_cast<T>(+0.5));
879+
// Inputs below this bound are forced to zero by the final IfThenElseZero.
const V kLowerBound =
880+
Set(d, static_cast<T>((sizeof(T) == 4 ? -104.0 : -1000.0)));
881+
// Only the sign bit is set; used below as a sign-bit mask.
const V kNegZero = Set(d, static_cast<T>(-0.0));
882+
883+
// Numerically 1 / ln(2), i.e. log2(e).
const V kOneOverLog2 = Set(d, static_cast<T>(+1.442695040888963407359924681));
884+
885+
using TI = MakeSigned<T>;
886+
const Rebind<TI, D> di;
887+
// 0.5 with the sign bit of x OR-ed in, i.e. +0.5 for non-negative x and -0.5
// for negative x (computed as kHalf | (x & kNegZero) on the integer bits).
const auto rounded_offs = BitCast(
888+
d, OrAnd(BitCast(di, kHalf), BitCast(di, x), BitCast(di, kNegZero)));
889+
890+
// q = x / ln(2) rounded to nearest: adding +/-0.5 before the truncating
// ToInt32 conversion turns round-toward-zero into round-to-nearest.
const auto q = impl.ToInt32(d, MulAdd(x, kOneOverLog2, rounded_offs));
891+
892+
// Reduce: x_red = x - q * ln(2), so that e^x = 2^q * e^x_red.
893+
const auto x_red = impl.ExpReduce(d, x, q);
894+
895+
// Approximate e^|x_red| with the rational function (A*y + B) / (C*y + D),
// where y = |x_red|.
896+
// x_in = |x_red| / 2 -> absorbed into coefficients
897+
// For negative x_red, numerator and denominator are swapped, which yields
// the reciprocal of the approximation (e^-y = 1 / e^y).
898+
899+
auto y = Abs(x_red);
900+
901+
const auto a = Set(d, static_cast<T>(-1757.05));
902+
const auto b = Set(d, static_cast<T>(-3128.2));
903+
const auto c = Set(d, static_cast<T>(1406.95));
904+
const auto d_coef = Set(d, static_cast<T>(-3130.2));
905+
906+
// res = (Ay + B) / (Cy + D)
907+
auto num = MulAdd(a, y, b);
908+
auto den = MulAdd(c, y, d_coef);
909+
910+
// If x_red < 0, swap num/den
911+
auto final_num = IfNegativeThenElse(x_red, den, num);
912+
auto final_den = IfNegativeThenElse(x_red, num, den);
913+
914+
auto approx = Div(final_num, final_den);
915+
916+
// Scale the approximation of e^x_red by 2^q.
const V res = impl.LoadExpShortRange(d, approx, q);
917+
918+
// Handle underflow: lanes where x >= kLowerBound does not hold become zero.
919+
return IfThenElseZero(Ge(x, kLowerBound), res);
920+
}
921+
781922
template <class D, class V>
782923
HWY_NOINLINE V CallFastAtan(const D d, VecArg<V> x) {
783924
return FastAtan(d, x);
@@ -803,6 +944,11 @@ HWY_NOINLINE V CallFastLog(const D d, VecArg<V> x) {
803944
return FastLog(d, x);
804945
}
805946

947+
// HWY_NOINLINE wrapper that forwards to FastExp(d, x) and returns its result.
template <class D, class V>
948+
HWY_NOINLINE V CallFastExp(const D d, VecArg<V> x) {
949+
return FastExp(d, x);
950+
}
951+
806952
} // namespace HWY_NAMESPACE
807953
} // namespace hwy
808954
HWY_AFTER_NAMESPACE();

hwy/contrib/math/math_test.cc

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
222222
}
223223

224224
double max_actual_rel_error = 0.0;
225+
double max_error_value = 0.0;
225226
// Emulation is slower, so cannot afford as many.
226227
const UintT kSamplesPerRange =
227228
static_cast<UintT>(AdjustedReps(static_cast<size_t>(samples)));
@@ -248,7 +249,10 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
248249
double rel = std::abs(static_cast<double>(actual) -
249250
static_cast<double>(expected)) /
250251
std::abs(static_cast<double>(expected));
251-
max_actual_rel_error = HWY_MAX(max_actual_rel_error, rel);
252+
if (ScalarIsNaN(rel) || rel > max_actual_rel_error) {
253+
max_actual_rel_error = rel;
254+
max_error_value = static_cast<double>(value);
255+
}
252256
if (rel > max_relative_error) {
253257
static int print_count = 0;
254258
if (print_count < 10) {
@@ -263,8 +267,9 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
263267
}
264268
}
265269
}
266-
fprintf(stderr, "%s: %s max_rel_error %E\n",
267-
hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error);
270+
fprintf(stderr, "%s: %s max_rel_error %E at %E\n",
271+
hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error,
272+
max_error_value);
268273
HWY_ASSERT(max_actual_rel_error <= max_relative_error);
269274
}
270275

@@ -283,6 +288,40 @@ struct TestFastLog {
283288
}
284289
};
285290

291+
// Compares CallFastExp against std::exp for one lane type T: a relative-error
// check over the normal-result range and a TestMath check over the range of
// very negative inputs.
struct TestFastExp {
292+
template <class T, class D>
293+
HWY_NOINLINE void operator()(T, D d) {
294+
if (sizeof(T) == 4) {
295+
// Float Normal Range: [-87.0, +88.0]
296+
// exp(-87) ~= 1.6e-38 (just above min normal 1.17e-38)
// NOTE(review): 0.0007 and 1e7 are presumably the maximum relative error
// and the sample count - confirm against TestMathRelative's signature.
297+
TestMathRelative<T, D>("FastExpNormal", std::exp, CallFastExp, d,
298+
static_cast<T>(-87.0), static_cast<T>(88.0),
299+
0.0007, 1e7);
300+
301+
// Float underflow/subnormal range as tested: [-FLT_MAX, -87.0]
302+
// exp(x) is subnormal or zero here. Error is dominated by quantization
303+
// (1 ULP ~= 50% relative error for small values), so this range is checked
// with TestMath rather than by relative error.
// NOTE(review): the trailing 1 is presumably the max ULP error - confirm.
304+
TestMath<T, D>("FastExpSubnormal", std::exp, CallFastExp, d,
305+
static_cast<T>(-FLT_MAX), static_cast<T>(-87.0), 1);
306+
} else {
307+
// Double Normal Range: [-708.0, +706.0]
308+
// exp(-708) ~= 3.3e-308 (just above min normal 2.23e-308)
309+
TestMathRelative<T, D>("FastExpNormal", std::exp, CallFastExp, d,
310+
static_cast<T>(-708.0), static_cast<T>(706.0),
311+
0.0007, 1e7);
312+
313+
// Double underflow/subnormal range as tested: [-DBL_MAX, -708.0]
314+
// exp(x) is subnormal or zero here. Quantization error is expected.
315+
TestMath<T, D>("FastExpSubnormal", std::exp, CallFastExp, d,
316+
static_cast<T>(-DBL_MAX), static_cast<T>(-708.0), 1);
317+
}
318+
}
319+
};
320+
321+
// Test entry point: passes TestFastExp, wrapped in ForPartialVectors, to
// ForFloat3264Types.
HWY_NOINLINE void TestAllFastExp() {
322+
ForFloat3264Types(ForPartialVectors<TestFastExp>());
323+
}
324+
286325
HWY_NOINLINE void TestAllFastLog() {
287326
ForFloat3264Types(ForPartialVectors<TestFastLog>());
288327
}
@@ -305,6 +344,7 @@ HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog10);
305344
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog1p);
306345
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog2);
307346
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastLog);
347+
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastExp);
308348
HWY_AFTER_TEST();
309349
} // namespace
310350
} // namespace hwy

0 commit comments

Comments
 (0)