Skip to content

Commit 2b8f016

Browse files
Nikhil0250 authored and copybara-github committed
Add FastLog() function to Highway math library.
Tests are added in math_test.cc to verify that the relative error of FastLog() stays within a specified threshold. High-level benchmark results: FastLog() has a maximum relative error of 0.063% over (0, +FLT_MAX] (float32) and (0, +DBL_MAX] (float64). The average relative error across the valid range is 0.000025%. Latency results: FastLog() is significantly faster than Hwy Log() for float32 (notably f32x4: 1.46x, f32: 1.37x, f32x8: 1.16x). PiperOrigin-RevId: 874206987
1 parent 3734cb1 commit 2b8f016

2 files changed

Lines changed: 347 additions & 28 deletions

File tree

hwy/contrib/math/fast_math-inl.h

Lines changed: 239 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -92,23 +92,34 @@ HWY_INLINE V FastTan(D d, V x) {
9292
auto idx_int = ConvertTo(RebindToSigned<D>(), idx_float);
9393

9494
HWY_ALIGN static constexpr T arr_a[] = {
95-
static_cast<T>(630.25357464271012), static_cast<T>(572.95779513082321), static_cast<T>(343.77467707849392),
96-
static_cast<T>(572.95779513082321), static_cast<T>(229.18311805232929), static_cast<T>(57.295779513082323),
95+
static_cast<T>(630.25357464271012), static_cast<T>(572.95779513082321),
96+
static_cast<T>(343.77467707849392), static_cast<T>(572.95779513082321),
97+
static_cast<T>(229.18311805232929), static_cast<T>(57.295779513082323),
9798
static_cast<T>(57.295779513082323), static_cast<T>(57.295779513082323)};
9899

99-
HWY_ALIGN static constexpr T arr_b[] = {
100-
static_cast<T>(0.0000000000000000), static_cast<T>(10.0000000000000000), static_cast<T>(46.0000000000000000),
101-
static_cast<T>(217.00000000000000), static_cast<T>(297.00000000000000), static_cast<T>(542.00000000000000),
102-
static_cast<T>(542.00000000000000), static_cast<T>(542.00000000000000)};
100+
HWY_ALIGN static constexpr T arr_b[] = {static_cast<T>(0.0000000000000000),
101+
static_cast<T>(10.0000000000000000),
102+
static_cast<T>(46.0000000000000000),
103+
static_cast<T>(217.00000000000000),
104+
static_cast<T>(297.00000000000000),
105+
static_cast<T>(542.00000000000000),
106+
static_cast<T>(542.00000000000000),
107+
static_cast<T>(542.00000000000000)};
103108

104109
HWY_ALIGN static constexpr T arr_c[] = {
105-
static_cast<T>(-57.295779513082323), static_cast<T>(-229.18311805232929), static_cast<T>(-286.47889756541161),
106-
static_cast<T>(-744.84513367007019), static_cast<T>(-572.95779513082321), static_cast<T>(-630.25357464271012),
107-
static_cast<T>(-630.25357464271012), static_cast<T>(-630.25357464271012)};
110+
static_cast<T>(-57.295779513082323),
111+
static_cast<T>(-229.18311805232929),
112+
static_cast<T>(-286.47889756541161),
113+
static_cast<T>(-744.84513367007019),
114+
static_cast<T>(-572.95779513082321),
115+
static_cast<T>(-630.25357464271012),
116+
static_cast<T>(-630.25357464271012),
117+
static_cast<T>(-630.25357464271012)};
108118

109119
HWY_ALIGN static constexpr T arr_d[] = {
110-
static_cast<T>(632.00000000000000), static_cast<T>(657.00000000000000), static_cast<T>(541.00000000000000),
111-
static_cast<T>(1252.0000000000000), static_cast<T>(910.00000000000000), static_cast<T>(990.00000000000000),
120+
static_cast<T>(632.00000000000000), static_cast<T>(657.00000000000000),
121+
static_cast<T>(541.00000000000000), static_cast<T>(1252.0000000000000),
122+
static_cast<T>(910.00000000000000), static_cast<T>(990.00000000000000),
112123
static_cast<T>(990.00000000000000), static_cast<T>(990.00000000000000)};
113124

114125
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
@@ -243,20 +254,31 @@ HWY_INLINE V FastAtan(D d, V val) {
243254
idx_i = Add(idx_i, And(VecFromMask(DI(), mask75), one_i));
244255

245256
HWY_ALIGN static constexpr T arr_a[] = {
246-
static_cast<T>(630.25357464271012), static_cast<T>(572.95779513082321), static_cast<T>(343.77467707849392),
247-
static_cast<T>(572.95779513082321), static_cast<T>(229.18311805232929), static_cast<T>(57.295779513082323),
257+
static_cast<T>(630.25357464271012), static_cast<T>(572.95779513082321),
258+
static_cast<T>(343.77467707849392), static_cast<T>(572.95779513082321),
259+
static_cast<T>(229.18311805232929), static_cast<T>(57.295779513082323),
248260
static_cast<T>(57.295779513082323), static_cast<T>(57.295779513082323)};
249-
HWY_ALIGN static constexpr T arr_b[] = {
250-
static_cast<T>(0.0000000000000000), static_cast<T>(10.0000000000000000), static_cast<T>(46.0000000000000000),
251-
static_cast<T>(217.00000000000000), static_cast<T>(297.00000000000000), static_cast<T>(542.00000000000000),
252-
static_cast<T>(542.00000000000000), static_cast<T>(542.00000000000000)};
261+
HWY_ALIGN static constexpr T arr_b[] = {static_cast<T>(0.0000000000000000),
262+
static_cast<T>(10.0000000000000000),
263+
static_cast<T>(46.0000000000000000),
264+
static_cast<T>(217.00000000000000),
265+
static_cast<T>(297.00000000000000),
266+
static_cast<T>(542.00000000000000),
267+
static_cast<T>(542.00000000000000),
268+
static_cast<T>(542.00000000000000)};
253269
HWY_ALIGN static constexpr T arr_c[] = {
254-
static_cast<T>(-57.295779513082323), static_cast<T>(-229.18311805232929), static_cast<T>(-286.47889756541161),
255-
static_cast<T>(-744.84513367007019), static_cast<T>(-572.95779513082321), static_cast<T>(-630.25357464271012),
256-
static_cast<T>(-630.25357464271012), static_cast<T>(-630.25357464271012)};
270+
static_cast<T>(-57.295779513082323),
271+
static_cast<T>(-229.18311805232929),
272+
static_cast<T>(-286.47889756541161),
273+
static_cast<T>(-744.84513367007019),
274+
static_cast<T>(-572.95779513082321),
275+
static_cast<T>(-630.25357464271012),
276+
static_cast<T>(-630.25357464271012),
277+
static_cast<T>(-630.25357464271012)};
257278
HWY_ALIGN static constexpr T arr_d[] = {
258-
static_cast<T>(632.00000000000000), static_cast<T>(657.00000000000000), static_cast<T>(541.00000000000000),
259-
static_cast<T>(1252.0000000000000), static_cast<T>(910.00000000000000), static_cast<T>(990.00000000000000),
279+
static_cast<T>(632.00000000000000), static_cast<T>(657.00000000000000),
280+
static_cast<T>(541.00000000000000), static_cast<T>(1252.0000000000000),
281+
static_cast<T>(910.00000000000000), static_cast<T>(990.00000000000000),
260282
static_cast<T>(990.00000000000000), static_cast<T>(990.00000000000000)};
261283

262284
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
@@ -464,21 +486,21 @@ HWY_INLINE V FastTanh(D d, V val) {
464486
idx_i = Min(idx_i, Set(DI(), 7));
465487

466488
HWY_ALIGN static constexpr T arr_a[] = {
467-
static_cast<T>(-7070.3), static_cast<T>(-287.1719),
468-
static_cast<T>(-38.3758), static_cast<T>(-12.0230),
469-
static_cast<T>(-4.4597), static_cast<T>(-2.0653),
470-
static_cast<T>(-1.0094), static_cast<T>(-0.4179)};
489+
static_cast<T>(-7070.3), static_cast<T>(-287.1719),
490+
static_cast<T>(-38.3758), static_cast<T>(-12.0230),
491+
static_cast<T>(-4.4597), static_cast<T>(-2.0653),
492+
static_cast<T>(-1.0094), static_cast<T>(-0.4179)};
471493
HWY_ALIGN static constexpr T arr_b[] = {
472494
static_cast<T>(1.0), static_cast<T>(1.0), static_cast<T>(1.0),
473495
static_cast<T>(1.0), static_cast<T>(1.0), static_cast<T>(1.0),
474496
static_cast<T>(1.0), static_cast<T>(1.0)};
475497
HWY_ALIGN static constexpr T arr_c[] = {
476-
static_cast<T>(-578.0), static_cast<T>(-67.0176),
498+
static_cast<T>(-578.0), static_cast<T>(-67.0176),
477499
static_cast<T>(-16.0803), static_cast<T>(-7.0634),
478500
static_cast<T>(-3.3816), static_cast<T>(-1.8164),
479501
static_cast<T>(-0.9760), static_cast<T>(-0.4175)};
480502
HWY_ALIGN static constexpr T arr_d[] = {
481-
static_cast<T>(-7027.2), static_cast<T>(-272.3521),
503+
static_cast<T>(-7027.2), static_cast<T>(-272.3521),
482504
static_cast<T>(-31.3271), static_cast<T>(-7.3286),
483505
static_cast<T>(-1.1620), static_cast<T>(0.4063),
484506
static_cast<T>(0.8946), static_cast<T>(0.9978)};
@@ -572,6 +594,190 @@ HWY_INLINE V FastTanh(D d, V val) {
572594
return CopySign(result, val); // Restore sign
573595
}
574596

597+
/**
598+
* Fast approximation of the natural logarithm log(x).
*
* Method (as implemented below): x is split into an integer exponent and a
* mantissa y renormalized to roughly [0.707, 1.414]. log(y) is approximated by
* a rational function (a*y + 1) / (c*y + d_coef) whose coefficients are chosen
* from 8 equal-width segments of that interval, and the result is
* exponent * ln(2) + log(y). Inputs below the smallest normal value are first
* scaled up by 2^25 (float32) or 2^54 (float64), and the exponent is corrected
* by the matching amount.
599+
*
600+
* Valid Lane Types: float32, float64
601+
* Max Relative Error: 0.063%
602+
* Average Relative Error : 0.000025%
603+
* Valid Range: float32: (0, +FLT_MAX]
604+
* float64: (0, +DBL_MAX]
* NOTE(review): no special-casing of zero, negative, infinite or NaN inputs is
* visible in this function; behavior outside the valid range should be
* confirmed against the tests.
605+
*
* @param d tag of type D describing the lane type and vector of 'x'
* @param x input vector
606+
* @return natural logarithm of 'x'
607+
*/
608+
template <class D, class V>
609+
HWY_INLINE V FastLog(D d, V x) {
610+
using T = TFromD<D>;
611+
using TI = MakeSigned<T>;
612+
using TU = MakeUnsigned<T>;
613+
const Rebind<TI, D> di;
614+
const Rebind<TU, D> du;
615+
using VI = decltype(Zero(di));
616+
617+
constexpr bool kIsF32 = (sizeof(T) == 4);
618+
619+
// Constants for Range Reduction
620+
// kMagic is the bit pattern of approx 1/sqrt(2) (for float64 only the upper
621+
// 32 bits are set). It is used to center the mantissa interval around 1.0
// (specifically [0.707, 1.414]).
622+
const VI kMagic = Set(di, kIsF32 ? static_cast<TI>(0x3F3504F3L)
623+
: static_cast<TI>(0x3FE6A09E00000000LL));
624+
// Bit pattern for 1.0 (despite the name, this is not a bit mask). Used in
625+
// the integer arithmetic to extract the exponent.
626+
const VI kExpMask = Set(di, kIsF32 ? static_cast<TI>(0x3F800000L)
627+
: static_cast<TI>(0x3FF0000000000000LL));
628+
// Integer exponent adjustment (-25 or -54) corresponding to kScale.
629+
const VI kExpScale =
630+
Set(di, kIsF32 ? static_cast<TI>(-25) : static_cast<TI>(-54));
631+
// Mantissa mask: all 23 mantissa bits for float32; only the upper 20 of the
// 52 mantissa bits for float64 (the lower 32 bits are covered by kLowerBits).
632+
const VI kManMask = Set(di, kIsF32 ? static_cast<TI>(0x7FFFFFL)
633+
: static_cast<TI>(0xFFFFF00000000LL));
634+
// Mask for the lower 32 bits of a float64; zero (selects nothing) for float32.
635+
const VI kLowerBits = Set(di, kIsF32 ? static_cast<TI>(0x00000000L)
636+
: static_cast<TI>(0xFFFFFFFFLL));
637+
// Smallest positive normal value (FLT_MIN for float32, DBL_MIN for float64).
const V kMinNormal = Set(d, kIsF32 ? static_cast<T>(1.175494351e-38f)
638+
: static_cast<T>(2.2250738585072014e-308));
639+
// Scale to normalize subnormal inputs: 2^25 (f32) or 2^54 (f64)
640+
const V kScale = Set(d, kIsF32 ? static_cast<T>(3.355443200e+7f)
641+
: static_cast<T>(1.8014398509481984e+16));
642+
// ln(2), used to convert the binary exponent into a natural logarithm.
const V kLn2 = Set(d, static_cast<T>(0.6931471805599453));
643+
644+
// Handle subnormals: lanes below the smallest normal value are multiplied
// by kScale; all other lanes keep x unchanged. The exponent of the scaled
// lanes is corrected below via exp_scale.
645+
const auto is_denormal = Lt(x, kMinNormal);
646+
x = MaskedMulOr(x, is_denormal, x, kScale);
647+
648+
// Compute exponent. Adding (bits(1.0) - kMagic) to the raw bits carries into
// the exponent field when the mantissa is at or above that of sqrt(2), so
// such inputs are assigned the next-higher exponent.
649+
auto exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic));
650+
// kExpScale for lanes that were scaled up above, zero for all other lanes.
const VI exp_scale =
651+
BitCast(di, IfThenElseZero(is_denormal, BitCast(d, kExpScale)));
652+
653+
constexpr int kMantissaShift = kIsF32 ? 23 : 52;
654+
const auto kBias = Set(di, kIsF32 ? 0x7F : 0x3FF);
655+
// Unsigned shift isolates the biased exponent field; subtracting kBias yields
// the unbiased integer exponent.
const auto exp_int = Sub(BitCast(di, ShiftRight<kMantissaShift>(
656+
BitCast(du, BitCast(d, exp_bits)))),
657+
kBias);
658+
const auto exp = ConvertTo(d, Add(exp_scale, exp_int));
659+
660+
// Renormalize x to y in [0.707, 1.414]: keep the (shifted) mantissa bits,
// add kMagic back, and OR in the low bits of x selected by kLowerBits.
661+
const auto x_bits = BitCast(di, x);
662+
const auto y_bits =
663+
OrAnd(Add(And(exp_bits, kManMask), kMagic), x_bits, kLowerBits);
664+
const V y = BitCast(d, y_bits);
665+
666+
// Piecewise rational approximation. t0..t6 are the interior boundaries of 8
// equal-width segments of [0.707, 1.414]; they are only read by the
// blend-chain fallback path below.
667+
const auto t0 = Set(d, static_cast<T>(0.7954951275));
668+
const auto t1 = Set(d, static_cast<T>(0.883883475));
669+
const auto t2 = Set(d, static_cast<T>(0.9722718225));
670+
const auto t3 = Set(d, static_cast<T>(1.06066017));
671+
const auto t4 = Set(d, static_cast<T>(1.1490485175));
672+
const auto t5 = Set(d, static_cast<T>(1.237436865));
673+
const auto t6 = Set(d, static_cast<T>(1.3258252125));
674+
675+
constexpr size_t kLanes = HWY_MAX_LANES_D(D);
676+
// Per-lane coefficients of the rational function (b is implicitly 1.0).
V a, c, d_coef;
677+
678+
if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
679+
(HWY_HAVE_SCALABLE && sizeof(T) == 4)) {
680+
// --- Table Lookup ---
681+
// scale is the reciprocal of the segment width (~8 / 0.70710678).
const auto scale = Set(d, static_cast<T>(11.3137085));
682+
// Input is always non-negative, so Floor() + ConvertTo()
683+
// can be replaced by direct ConvertTo() (truncation), which is faster.
684+
// We use MulAdd(y, scale, -8.0) instead of Mul(Sub(y, lower_bound), scale)
685+
// to save instructions. 0.70710678 * 11.3137085 ~= 8.0.
686+
auto idx_i = ConvertInRangeTo(
687+
RebindToSigned<D>(), MulAdd(y, scale, Set(d, static_cast<T>(-8.0))));
688+
689+
// Clamp index to 7 to handle overshoots
690+
idx_i = Min(idx_i, Set(RebindToSigned<D>(), 7));
691+
692+
// Coefficient tables, one entry per segment (index 0 = lowest segment).
HWY_ALIGN static constexpr T arr_a[] = {
693+
static_cast<T>(-0.9981), static_cast<T>(-0.9996),
694+
static_cast<T>(-1.0000), static_cast<T>(-1.0000),
695+
static_cast<T>(-1.0001), static_cast<T>(-1.0004),
696+
static_cast<T>(-1.0013), static_cast<T>(-1.0026)};
697+
// b array is not needed since b is always 1.0.
698+
HWY_ALIGN static constexpr T arr_c[] = {
699+
static_cast<T>(-0.5825), static_cast<T>(-0.5478),
700+
static_cast<T>(-0.5181), static_cast<T>(-0.4974),
701+
static_cast<T>(-0.4763), static_cast<T>(-0.4597),
702+
static_cast<T>(-0.4454), static_cast<T>(-0.4332)};
703+
HWY_ALIGN static constexpr T arr_d[] = {
704+
static_cast<T>(-0.4371), static_cast<T>(-0.4595),
705+
static_cast<T>(-0.4829), static_cast<T>(-0.5025),
706+
static_cast<T>(-0.5260), static_cast<T>(-0.5482),
707+
static_cast<T>(-0.5706), static_cast<T>(-0.5932)};
708+
709+
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
710+
// All 8 table entries fit in one vector: a single lane lookup per table.
// NOTE(review): Load(d, arr_*) reads a full vector of type D from an
// 8-entry array -- confirm this is in bounds when kLanes > 8.
auto idx = IndicesFromVec(d, idx_i);
711+
a = TableLookupLanes(Load(d, arr_a), idx);
712+
c = TableLookupLanes(Load(d, arr_c), idx);
713+
d_coef = TableLookupLanes(Load(d, arr_d), idx);
714+
} else {
715+
// Each table is loaded as two 4-lane halves and indexed as a pair.
auto idx = IndicesFromVec(d, idx_i);
716+
FixedTag<T, 4> d4;
717+
a = TwoTablesLookupLanes(d, Load(d4, arr_a), Load(d4, arr_a + 4), idx);
718+
c = TwoTablesLookupLanes(d, Load(d4, arr_c), Load(d4, arr_c + 4), idx);
719+
d_coef =
720+
TwoTablesLookupLanes(d, Load(d4, arr_d), Load(d4, arr_d + 4), idx);
721+
}
722+
} else {
723+
// --- FALLBACK PATH: Blend Chain ---
// Selects the same coefficients as the tables above by comparing y against
// the segment boundaries from highest to lowest; the last matching (lowest)
// segment wins.
724+
// Start with highest index (7)
725+
a = Set(d, static_cast<T>(-1.0026));
726+
c = Set(d, static_cast<T>(-0.4332));
727+
d_coef = Set(d, static_cast<T>(-0.5932));
728+
729+
// If y < t6 (idx 6)
730+
auto mask = Lt(y, t6);
731+
a = IfThenElse(mask, Set(d, static_cast<T>(-1.0013)), a);
732+
c = IfThenElse(mask, Set(d, static_cast<T>(-0.4454)), c);
733+
d_coef = IfThenElse(mask, Set(d, static_cast<T>(-0.5706)), d_coef);
734+
735+
// If y < t5 (idx 5)
736+
mask = Lt(y, t5);
737+
a = IfThenElse(mask, Set(d, static_cast<T>(-1.0004)), a);
738+
c = IfThenElse(mask, Set(d, static_cast<T>(-0.4597)), c);
739+
d_coef = IfThenElse(mask, Set(d, static_cast<T>(-0.5482)), d_coef);
740+
741+
// If y < t4 (idx 4)
742+
mask = Lt(y, t4);
743+
a = IfThenElse(mask, Set(d, static_cast<T>(-1.0001)), a);
744+
c = IfThenElse(mask, Set(d, static_cast<T>(-0.4763)), c);
745+
d_coef = IfThenElse(mask, Set(d, static_cast<T>(-0.5260)), d_coef);
746+
747+
// If y < t3 (idx 3)
748+
mask = Lt(y, t3);
749+
a = IfThenElse(mask, Set(d, static_cast<T>(-1.0000)), a);
750+
c = IfThenElse(mask, Set(d, static_cast<T>(-0.4974)), c);
751+
d_coef = IfThenElse(mask, Set(d, static_cast<T>(-0.5025)), d_coef);
752+
753+
// If y < t2 (idx 2)
754+
mask = Lt(y, t2);
755+
a = IfThenElse(mask, Set(d, static_cast<T>(-1.0000)), a);
756+
c = IfThenElse(mask, Set(d, static_cast<T>(-0.5181)), c);
757+
d_coef = IfThenElse(mask, Set(d, static_cast<T>(-0.4829)), d_coef);
758+
759+
// If y < t1 (idx 1)
760+
mask = Lt(y, t1);
761+
a = IfThenElse(mask, Set(d, static_cast<T>(-0.9996)), a);
762+
c = IfThenElse(mask, Set(d, static_cast<T>(-0.5478)), c);
763+
d_coef = IfThenElse(mask, Set(d, static_cast<T>(-0.4595)), d_coef);
764+
765+
// If y < t0 (idx 0)
766+
mask = Lt(y, t0);
767+
a = IfThenElse(mask, Set(d, static_cast<T>(-0.9981)), a);
768+
c = IfThenElse(mask, Set(d, static_cast<T>(-0.5825)), c);
769+
d_coef = IfThenElse(mask, Set(d, static_cast<T>(-0.4371)), d_coef);
770+
}
771+
772+
// Math: approx = (a*y + 1.0) / (c*y + d_coef), approximating log(y)
773+
auto num = MulAdd(a, y, Set(d, static_cast<T>(1.0)));
774+
auto den = MulAdd(c, y, d_coef);
775+
776+
auto approx = Div(num, den);
777+
778+
// log(x) = exp * ln(2) + log(y)
return MulAdd(exp, kLn2, approx);
779+
}
780+
575781
template <class D, class V>
576782
HWY_NOINLINE V CallFastAtan(const D d, VecArg<V> x) {
577783
return FastAtan(d, x);
@@ -592,6 +798,11 @@ HWY_NOINLINE V CallFastTanh(const D d, VecArg<V> x) {
592798
return FastTanh(d, x);
593799
}
594800

801+
// Wrapper marked HWY_NOINLINE that forwards its arguments to FastLog(d, x)
// and returns the result unchanged.
template <class D, class V>
802+
HWY_NOINLINE V CallFastLog(const D d, VecArg<V> x) {
803+
return FastLog(d, x);
804+
}
805+
595806
} // namespace HWY_NAMESPACE
596807
} // namespace hwy
597808
HWY_AFTER_NAMESPACE();

0 commit comments

Comments
 (0)