Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions hwy/contrib/math/fast_math-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,84 @@ HWY_INLINE void ReduceAngleTan(D d, V ang, V& x_red, V& sign) {

} // namespace impl

namespace impl {

template <class T>
struct FastExpImpl {};

template <>
struct FastExpImpl<float> {
// Rounds float toward zero and returns as int32_t.
template <class D, class V>
HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
return ConvertInRangeTo(Rebind<int32_t, D>(), x);
}

// Computes 2^x, where x is an integer.
template <class D, class VI32>
HWY_INLINE Vec<D> Pow2I(D d, VI32 x) {
const Rebind<int32_t, D> di32;
const VI32 kOffset = Set(di32, 0x7F);
return BitCast(d, ShiftLeft<23>(Add(x, kOffset)));
}

// Sets the exponent of 'x' to 2^e.
template <class D, class V, class VI32>
HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) {
const VI32 y = ShiftRight<1>(e);
return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y)));
}

template <class D, class V, class VI32>
HWY_INLINE V ExpReduce(D d, V x, VI32 q) {
// kMinusLn2 ~= -ln(2)
const V kMinusLn2 = Set(d, -0.69314718056f);

// Extended precision modular arithmetic.
const V qf = ConvertTo(d, q);
return MulAdd(qf, kMinusLn2, x);
}
};

#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64
template <>
struct FastExpImpl<double> {
// Rounds double toward zero and returns as int32_t.
template <class D, class V>
HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
return DemoteInRangeTo(Rebind<int32_t, D>(), x);
}

// Computes 2^x, where x is an integer.
template <class D, class VI32>
HWY_INLINE Vec<D> Pow2I(D d, VI32 x) {
const Rebind<int32_t, D> di32;
const Rebind<int64_t, D> di64;
const VI32 kOffset = Set(di32, 0x3FF);
return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset))));
}

// Sets the exponent of 'x' to 2^e.
template <class D, class V, class VI32>
HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) {
const VI32 y = ShiftRight<1>(e);
return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y)));
}

template <class D, class V, class VI32>
HWY_INLINE V ExpReduce(D d, V x, VI32 q) {
// kMinusLn2 ~= -ln(2)
const V kMinusLn2 = Set(d, -0.6931471805599453);

// Extended precision modular arithmetic.
const V qf = PromoteTo(d, q);
return MulAdd(qf, kMinusLn2, x);
}
};
#endif

} // namespace impl

/**
* Fast approximation of tan(x).
*
Expand Down Expand Up @@ -778,6 +856,69 @@ HWY_INLINE V FastLog(D d, V x) {
return MulAdd(exp, kLn2, approx);
}

/**
* Fast approximation of exp(x).
*
* Valid Lane Types: float32, float64
* Max ULP Error: 1 for float32 [-FLT_MAX, -87]
* Max ULP Error: 1 for float64 [-DBL_MAX, -708]
* Max Relative Error: 0.06% for float32 [-87, 88]
* Max Relative Error: 0.06% for float64 [-708, 706]
* Average Relative Error: 0.05% for float32 [-87, 88]
* Average Relative Error: 0.06% for float64 [-708, 706]
* Valid Range: float32[-FLT_MAX, +88], float64[-DBL_MAX, +706]
*
* @return e^x
*/
template <class D, class V>
HWY_INLINE V FastExp(D d, V x) {
using T = TFromD<D>;
impl::FastExpImpl<T> impl;

const V kHalf = Set(d, static_cast<T>(+0.5));
const V kLowerBound =
Set(d, static_cast<T>((sizeof(T) == 4 ? -104.0 : -1000.0)));
const V kNegZero = Set(d, static_cast<T>(-0.0));

const V kOneOverLog2 = Set(d, static_cast<T>(+1.442695040888963407359924681));

using TI = MakeSigned<T>;
const Rebind<TI, D> di;
const auto rounded_offs = BitCast(
d, OrAnd(BitCast(di, kHalf), BitCast(di, x), BitCast(di, kNegZero)));

const auto q = impl.ToInt32(d, MulAdd(x, kOneOverLog2, rounded_offs));

// Reduce
const auto x_red = impl.ExpReduce(d, x, q);

// New logic:
// x_in = |x_red| / 2 -> absorbed into coefficients
// if x_red < 0: swap num/den

auto y = Abs(x_red);

const auto a = Set(d, static_cast<T>(-1757.05));
const auto b = Set(d, static_cast<T>(-3128.2));
const auto c = Set(d, static_cast<T>(1406.95));
const auto d_coef = Set(d, static_cast<T>(-3130.2));

// res = (Ay + B) / (Cy + D)
auto num = MulAdd(a, y, b);
auto den = MulAdd(c, y, d_coef);

// If x_red < 0, swap num/den
auto final_num = IfNegativeThenElse(x_red, den, num);
auto final_den = IfNegativeThenElse(x_red, num, den);

auto approx = Div(final_num, final_den);

const V res = impl.LoadExpShortRange(d, approx, q);

// Handle underflow
return IfThenElseZero(Ge(x, kLowerBound), res);
}

template <class D, class V>
HWY_NOINLINE V CallFastAtan(const D d, VecArg<V> x) {
return FastAtan(d, x);
Expand All @@ -803,6 +944,11 @@ HWY_NOINLINE V CallFastLog(const D d, VecArg<V> x) {
return FastLog(d, x);
}

template <class D, class V>
HWY_NOINLINE V CallFastExp(const D d, VecArg<V> x) {
return FastExp(d, x);
}

} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
Expand Down
46 changes: 43 additions & 3 deletions hwy/contrib/math/math_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
}

double max_actual_rel_error = 0.0;
double max_error_value = 0.0;
// Emulation is slower, so cannot afford as many.
const UintT kSamplesPerRange =
static_cast<UintT>(AdjustedReps(static_cast<size_t>(samples)));
Expand All @@ -248,7 +249,10 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
double rel = std::abs(static_cast<double>(actual) -
static_cast<double>(expected)) /
std::abs(static_cast<double>(expected));
max_actual_rel_error = HWY_MAX(max_actual_rel_error, rel);
if (ScalarIsNaN(rel) || rel > max_actual_rel_error) {
max_actual_rel_error = rel;
max_error_value = static_cast<double>(value);
}
if (rel > max_relative_error) {
static int print_count = 0;
if (print_count < 10) {
Expand All @@ -263,8 +267,9 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
}
}
}
fprintf(stderr, "%s: %s max_rel_error %E\n",
hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error);
fprintf(stderr, "%s: %s max_rel_error %E at %E\n",
hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error,
max_error_value);
HWY_ASSERT(max_actual_rel_error <= max_relative_error);
}

Expand All @@ -283,6 +288,40 @@ struct TestFastLog {
}
};

struct TestFastExp {
template <class T, class D>
HWY_NOINLINE void operator()(T, D d) {
if (sizeof(T) == 4) {
// Float Normal Range: [-87.0, +88.0]
// exp(-87) ~= 1.6e-38 (just above min normal 1.17e-38)
TestMathRelative<T, D>("FastExpNormal", std::exp, CallFastExp, d,
static_cast<T>(-87.0), static_cast<T>(88.0),
0.0007, 1e7);

// Float Subnormal Range: [-104.0, -87.0]
// exp(-104) is close to 0. Error is dominated by quantization (1 ULP ~=
// 50% relative error for small values).
TestMath<T, D>("FastExpSubnormal", std::exp, CallFastExp, d,
static_cast<T>(-FLT_MAX), static_cast<T>(-87.0), 1);
} else {
// Double Normal Range: [-708.0, +706.0]
// exp(-708) ~= 2.2e-308 (min normal 2.22e-308)
TestMathRelative<T, D>("FastExpNormal", std::exp, CallFastExp, d,
static_cast<T>(-708.0), static_cast<T>(706.0),
0.0007, 1e7);

// Double Subnormal Range: [-744.0, -708.0]
// exp(-744) is very small. Quantization error is expected.
TestMath<T, D>("FastExpSubnormal", std::exp, CallFastExp, d,
static_cast<T>(-DBL_MAX), static_cast<T>(-708.0), 1);
}
}
};

HWY_NOINLINE void TestAllFastExp() {
ForFloat3264Types(ForPartialVectors<TestFastExp>());
}

HWY_NOINLINE void TestAllFastLog() {
ForFloat3264Types(ForPartialVectors<TestFastLog>());
}
Expand All @@ -305,6 +344,7 @@ HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog10);
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog1p);
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog2);
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastLog);
HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastExp);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down
Loading