Skip to content

Commit cf466cc

Browse files
Nikhil0250, copybara-github
authored and committed
Improve FastLog()
PiperOrigin-RevId: 909921001
1 parent af486ba commit cf466cc

2 files changed

Lines changed: 134 additions & 98 deletions

File tree

hwy/contrib/math/fast_math-inl.h

Lines changed: 133 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -904,8 +904,8 @@ HWY_INLINE V FastTanh(D d, V val) {
904904
* Fast approximation of log(x).
905905
*
906906
* Valid Lane Types: float32, float64
907-
* Max Relative Error: 0.0095%
908-
* Average Relative Error : 0.000014%
907+
* Max Relative Error: 0.0012%
908+
* Average Relative Error: 3e-6% for float32, 1.8e-7% for float64
909909
* Valid Range: float32: (0, +FLT_MAX]
910910
* float64: (0, +DBL_MAX]
911911
*
@@ -920,7 +920,13 @@ HWY_INLINE V FastLog(D d, V x) {
920920
impl::FastLogRangeReduction<kHandleSubnormals>(d, x, y, exp);
921921

922922
constexpr size_t kLanes = HWY_MAX_LANES_D(D);
923-
V b, c, d_coef;
923+
V approx;
924+
925+
V a, b, c, d_val;
926+
// Centering the approximation around y=1.0 by using z = y - 1.0 significantly
927+
// improves accuracy for low-degree polynomials compared to approximating
928+
// log(y) directly.
929+
const V z = Sub(y, Set(d, static_cast<T>(1.0)));
924930

925931
if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
926932
(HWY_HAVE_SCALABLE && sizeof(T) == 4 && detail::IsFull(d))) {
@@ -936,45 +942,55 @@ HWY_INLINE V FastLog(D d, V x) {
936942
// Clamp index to 7 to handle overshoots
937943
idx_i = Min(idx_i, Set(RebindToSigned<D>(), 7));
938944

945+
HWY_ALIGN static constexpr T arr_a[8] = {
946+
static_cast<T>(0.78766119873962426),
947+
static_cast<T>(0.56395605885234767),
948+
static_cast<T>(0.41755823888409732),
949+
static_cast<T>(0.31775546220809975),
950+
static_cast<T>(0.24738922014476947),
951+
static_cast<T>(0.19635862241628779),
952+
static_cast<T>(0.15845269802741027),
953+
static_cast<T>(0.12974944454997622)};
954+
939955
HWY_ALIGN static constexpr T arr_b[8] = {
940-
static_cast<T>(-1.00194730895928918),
941-
static_cast<T>(-1.00042661239958708),
942-
static_cast<T>(-1.0000255203465902),
943-
static_cast<T>(-1),
944-
static_cast<T>(-0.999929163668789478),
945-
static_cast<T>(-0.999558969823431065),
946-
static_cast<T>(-0.998743736501089163),
947-
static_cast<T>(-0.997397894886509873)};
956+
static_cast<T>(-0.29967724727628686),
957+
static_cast<T>(-0.43890059639104201),
958+
static_cast<T>(-0.49106265092580692),
959+
static_cast<T>(-0.50008637171949),
960+
static_cast<T>(-0.48774751153444412),
961+
static_cast<T>(-0.46524164863536055),
962+
static_cast<T>(-0.43845622820596808),
963+
static_cast<T>(-0.41055729878598496)};
948964

949965
HWY_ALIGN static constexpr T arr_c[8] = {
950-
static_cast<T>(0.58385589069067223),
951-
static_cast<T>(0.548174514768112076),
952-
static_cast<T>(0.519613079391819999),
953-
static_cast<T>(0.497367242550162236),
954-
static_cast<T>(0.476391677761481835),
955-
static_cast<T>(0.459525070958496262),
956-
static_cast<T>(0.44490172854808846),
957-
static_cast<T>(0.432070989622927948)};
966+
static_cast<T>(1.0358118335702087),
967+
static_cast<T>(1.0067153345685411),
968+
static_cast<T>(1.000379283174812),
969+
static_cast<T>(1.0000110351938951),
970+
static_cast<T>(0.99922180103707492),
971+
static_cast<T>(0.99586383428901692),
972+
static_cast<T>(0.98951797571207256),
973+
static_cast<T>(0.98045123777070986)};
958974

959975
HWY_ALIGN static constexpr T arr_d[8] = {
960-
static_cast<T>(0.437891917978712797),
961-
static_cast<T>(0.459658304416673158),
962-
static_cast<T>(0.481694216614368509),
963-
static_cast<T>(0.502574248959839265),
964-
static_cast<T>(0.525922172040079627),
965-
static_cast<T>(0.547948723977362273),
966-
static_cast<T>(0.569860763464220654),
967-
static_cast<T>(0.591637568597068619)};
976+
static_cast<T>(0.0023082932745966296),
977+
static_cast<T>(0.00026712584767189665),
978+
static_cast<T>(5.4447452042709148e-06),
979+
static_cast<T>(0),
980+
static_cast<T>(1.8158320065986679e-05),
981+
static_cast<T>(0.00018763480353754217),
982+
static_cast<T>(0.0006917031865196592),
983+
static_cast<T>(0.0016769113228540019)};
968984

969985
// Since Lookup8 is available for HWY_MIN_BYTES / sizeof(T) >= 4, this
970986
// condition covers all cases we encounter inside the top level if block
971987
// inside FastLog
988+
a = Lookup8(d, arr_a, idx_i);
972989
b = Lookup8(d, arr_b, idx_i);
973990
c = Lookup8(d, arr_c, idx_i);
974-
d_coef = Lookup8(d, arr_d, idx_i);
991+
d_val = Lookup8(d, arr_d, idx_i);
975992
} else {
976993
// --- FALLBACK PATH: Blend Chain ---
977-
// Polynomial Approximation
978994
const auto t0 = Set(d, static_cast<T>(0.7954951287634819));
979995
const auto t1 = Set(d, static_cast<T>(0.8838834764038688));
980996
const auto t2 = Set(d, static_cast<T>(0.9722718240442556));
@@ -985,132 +1001,152 @@ HWY_INLINE V FastLog(D d, V x) {
9851001

9861002
if constexpr (HWY_REGISTERS >= 32) {
9871003
// Split into two parallel chains to reduce dependency latency.
988-
9891004
// -- Chain 1: Indices 0 to 3 (Evaluated starting from t2 down to t0)
990-
auto b_low = Set(d, static_cast<T>(-1)); // idx 3
991-
auto c_low = Set(d, static_cast<T>(0.497367242550162236));
992-
auto d_low = Set(d, static_cast<T>(0.502574248959839265));
1005+
auto a_low = Set(d, static_cast<T>(0.31775546220809975));
1006+
auto b_low = Set(d, static_cast<T>(-0.50008637171949));
1007+
auto c_low = Set(d, static_cast<T>(1.0000110351938951));
1008+
auto d_low = Set(d, static_cast<T>(0));
9931009

9941010
auto mask = Lt(y, t2);
1011+
a_low =
1012+
IfThenElse(mask, Set(d, static_cast<T>(0.41755823888409732)), a_low);
9951013
b_low =
996-
IfThenElse(mask, Set(d, static_cast<T>(-1.0000255203465902)), b_low);
1014+
IfThenElse(mask, Set(d, static_cast<T>(-0.49106265092580692)), b_low);
9971015
c_low =
998-
IfThenElse(mask, Set(d, static_cast<T>(0.519613079391819999)), c_low);
999-
d_low =
1000-
IfThenElse(mask, Set(d, static_cast<T>(0.481694216614368509)), d_low);
1016+
IfThenElse(mask, Set(d, static_cast<T>(1.000379283174812)), c_low);
1017+
d_low = IfThenElse(mask, Set(d, static_cast<T>(5.4447452042709148e-06)),
1018+
d_low);
10011019

10021020
mask = Lt(y, t1);
1021+
a_low =
1022+
IfThenElse(mask, Set(d, static_cast<T>(0.56395605885234767)), a_low);
10031023
b_low =
1004-
IfThenElse(mask, Set(d, static_cast<T>(-1.00042661239958708)), b_low);
1024+
IfThenElse(mask, Set(d, static_cast<T>(-0.43890059639104201)), b_low);
10051025
c_low =
1006-
IfThenElse(mask, Set(d, static_cast<T>(0.548174514768112076)), c_low);
1007-
d_low =
1008-
IfThenElse(mask, Set(d, static_cast<T>(0.459658304416673158)), d_low);
1026+
IfThenElse(mask, Set(d, static_cast<T>(1.0067153345685411)), c_low);
1027+
d_low = IfThenElse(mask, Set(d, static_cast<T>(0.00026712584767189665)),
1028+
d_low);
10091029

10101030
mask = Lt(y, t0);
1031+
a_low =
1032+
IfThenElse(mask, Set(d, static_cast<T>(0.78766119873962426)), a_low);
10111033
b_low =
1012-
IfThenElse(mask, Set(d, static_cast<T>(-1.00194730895928918)), b_low);
1034+
IfThenElse(mask, Set(d, static_cast<T>(-0.29967724727628686)), b_low);
10131035
c_low =
1014-
IfThenElse(mask, Set(d, static_cast<T>(0.58385589069067223)), c_low);
1015-
d_low =
1016-
IfThenElse(mask, Set(d, static_cast<T>(0.437891917978712797)), d_low);
1036+
IfThenElse(mask, Set(d, static_cast<T>(1.0358118335702087)), c_low);
1037+
d_low = IfThenElse(mask, Set(d, static_cast<T>(0.0023082932745966296)),
1038+
d_low);
10171039

10181040
// -- Chain 2: Indices 4 to 7 (Evaluated starting from t6 down to t4)
1019-
auto b_high = Set(d, static_cast<T>(-0.997397894886509873)); // idx 7
1020-
auto c_high = Set(d, static_cast<T>(0.432070989622927948));
1021-
auto d_high = Set(d, static_cast<T>(0.591637568597068619));
1041+
auto a_high = Set(d, static_cast<T>(0.12974944454997622));
1042+
auto b_high = Set(d, static_cast<T>(-0.41055729878598496));
1043+
auto c_high = Set(d, static_cast<T>(0.98045123777070986));
1044+
auto d_high = Set(d, static_cast<T>(0.0016769113228540019));
10221045

10231046
mask = Lt(y, t6);
1024-
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.998743736501089163)),
1047+
a_high =
1048+
IfThenElse(mask, Set(d, static_cast<T>(0.15845269802741027)), a_high);
1049+
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.43845622820596808)),
10251050
b_high);
10261051
c_high =
1027-
IfThenElse(mask, Set(d, static_cast<T>(0.44490172854808846)), c_high);
1028-
d_high = IfThenElse(mask, Set(d, static_cast<T>(0.569860763464220654)),
1052+
IfThenElse(mask, Set(d, static_cast<T>(0.98951797571207256)), c_high);
1053+
d_high = IfThenElse(mask, Set(d, static_cast<T>(0.0006917031865196592)),
10291054
d_high);
10301055

10311056
mask = Lt(y, t5);
1032-
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.999558969823431065)),
1057+
a_high =
1058+
IfThenElse(mask, Set(d, static_cast<T>(0.19635862241628779)), a_high);
1059+
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.46524164863536055)),
10331060
b_high);
1034-
c_high = IfThenElse(mask, Set(d, static_cast<T>(0.459525070958496262)),
1035-
c_high);
1036-
d_high = IfThenElse(mask, Set(d, static_cast<T>(0.547948723977362273)),
1061+
c_high =
1062+
IfThenElse(mask, Set(d, static_cast<T>(0.99586383428901692)), c_high);
1063+
d_high = IfThenElse(mask, Set(d, static_cast<T>(0.00018763480353754217)),
10371064
d_high);
10381065

10391066
mask = Lt(y, t4);
1040-
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.999929163668789478)),
1067+
a_high =
1068+
IfThenElse(mask, Set(d, static_cast<T>(0.24738922014476947)), a_high);
1069+
b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.48774751153444412)),
10411070
b_high);
1042-
c_high = IfThenElse(mask, Set(d, static_cast<T>(0.476391677761481835)),
1043-
c_high);
1044-
d_high = IfThenElse(mask, Set(d, static_cast<T>(0.525922172040079627)),
1071+
c_high =
1072+
IfThenElse(mask, Set(d, static_cast<T>(0.99922180103707492)), c_high);
1073+
d_high = IfThenElse(mask, Set(d, static_cast<T>(1.8158320065986679e-05)),
10451074
d_high);
10461075

10471076
// -- Merge the two chains
10481077
auto merge_mask = Lt(y, t3);
1078+
a = IfThenElse(merge_mask, a_low, a_high);
10491079
b = IfThenElse(merge_mask, b_low, b_high);
10501080
c = IfThenElse(merge_mask, c_low, c_high);
1051-
d_coef = IfThenElse(merge_mask, d_low, d_high);
1081+
d_val = IfThenElse(merge_mask, d_low, d_high);
10521082
} else {
10531083
// Start with highest index (7)
1054-
b = Set(d, static_cast<T>(-0.997397894886509873));
1055-
c = Set(d, static_cast<T>(0.432070989622927948));
1056-
d_coef = Set(d, static_cast<T>(0.591637568597068619));
1084+
a = Set(d, static_cast<T>(0.12974944454997622));
1085+
b = Set(d, static_cast<T>(-0.41055729878598496));
1086+
c = Set(d, static_cast<T>(0.98045123777070986));
1087+
d_val = Set(d, static_cast<T>(0.0016769113228540019));
10571088

10581089
// If y < t6 (idx 6)
10591090
auto mask = Lt(y, t6);
1060-
b = IfThenElse(mask, Set(d, static_cast<T>(-0.998743736501089163)), b);
1061-
c = IfThenElse(mask, Set(d, static_cast<T>(0.44490172854808846)), c);
1062-
d_coef = IfThenElse(mask, Set(d, static_cast<T>(0.569860763464220654)),
1063-
d_coef);
1091+
a = IfThenElse(mask, Set(d, static_cast<T>(0.15845269802741027)), a);
1092+
b = IfThenElse(mask, Set(d, static_cast<T>(-0.43845622820596808)), b);
1093+
c = IfThenElse(mask, Set(d, static_cast<T>(0.98951797571207256)), c);
1094+
d_val = IfThenElse(mask, Set(d, static_cast<T>(0.0006917031865196592)),
1095+
d_val);
10641096

10651097
// If y < t5 (idx 5)
10661098
mask = Lt(y, t5);
1067-
b = IfThenElse(mask, Set(d, static_cast<T>(-0.999558969823431065)), b);
1068-
c = IfThenElse(mask, Set(d, static_cast<T>(0.459525070958496262)), c);
1069-
d_coef = IfThenElse(mask, Set(d, static_cast<T>(0.547948723977362273)),
1070-
d_coef);
1099+
a = IfThenElse(mask, Set(d, static_cast<T>(0.19635862241628779)), a);
1100+
b = IfThenElse(mask, Set(d, static_cast<T>(-0.46524164863536055)), b);
1101+
c = IfThenElse(mask, Set(d, static_cast<T>(0.99586383428901692)), c);
1102+
d_val = IfThenElse(mask, Set(d, static_cast<T>(0.00018763480353754217)),
1103+
d_val);
10711104

10721105
// If y < t4 (idx 4)
10731106
mask = Lt(y, t4);
1074-
b = IfThenElse(mask, Set(d, static_cast<T>(-0.999929163668789478)), b);
1075-
c = IfThenElse(mask, Set(d, static_cast<T>(0.476391677761481835)), c);
1076-
d_coef = IfThenElse(mask, Set(d, static_cast<T>(0.525922172040079627)),
1077-
d_coef);
1107+
a = IfThenElse(mask, Set(d, static_cast<T>(0.24738922014476947)), a);
1108+
b = IfThenElse(mask, Set(d, static_cast<T>(-0.48774751153444412)), b);
1109+
c = IfThenElse(mask, Set(d, static_cast<T>(0.99922180103707492)), c);
1110+
d_val = IfThenElse(mask, Set(d, static_cast<T>(1.8158320065986679e-05)),
1111+
d_val);
10781112

10791113
// If y < t3 (idx 3)
10801114
mask = Lt(y, t3);
1081-
b = IfThenElse(mask, Set(d, static_cast<T>(-1)), b);
1082-
c = IfThenElse(mask, Set(d, static_cast<T>(0.497367242550162236)), c);
1083-
d_coef = IfThenElse(mask, Set(d, static_cast<T>(0.502574248959839265)),
1084-
d_coef);
1115+
a = IfThenElse(mask, Set(d, static_cast<T>(0.31775546220809975)), a);
1116+
b = IfThenElse(mask, Set(d, static_cast<T>(-0.50008637171949)), b);
1117+
c = IfThenElse(mask, Set(d, static_cast<T>(1.0000110351938951)), c);
1118+
d_val = IfThenElse(mask, Set(d, static_cast<T>(0)), d_val);
10851119

10861120
// If y < t2 (idx 2)
10871121
mask = Lt(y, t2);
1088-
b = IfThenElse(mask, Set(d, static_cast<T>(-1.0000255203465902)), b);
1089-
c = IfThenElse(mask, Set(d, static_cast<T>(0.519613079391819999)), c);
1090-
d_coef = IfThenElse(mask, Set(d, static_cast<T>(0.481694216614368509)),
1091-
d_coef);
1122+
a = IfThenElse(mask, Set(d, static_cast<T>(0.41755823888409732)), a);
1123+
b = IfThenElse(mask, Set(d, static_cast<T>(-0.49106265092580692)), b);
1124+
c = IfThenElse(mask, Set(d, static_cast<T>(1.000379283174812)), c);
1125+
d_val = IfThenElse(mask, Set(d, static_cast<T>(5.4447452042709148e-06)),
1126+
d_val);
10921127

10931128
// If y < t1 (idx 1)
10941129
mask = Lt(y, t1);
1095-
b = IfThenElse(mask, Set(d, static_cast<T>(-1.00042661239958708)), b);
1096-
c = IfThenElse(mask, Set(d, static_cast<T>(0.548174514768112076)), c);
1097-
d_coef = IfThenElse(mask, Set(d, static_cast<T>(0.459658304416673158)),
1098-
d_coef);
1130+
a = IfThenElse(mask, Set(d, static_cast<T>(0.56395605885234767)), a);
1131+
b = IfThenElse(mask, Set(d, static_cast<T>(-0.43890059639104201)), b);
1132+
c = IfThenElse(mask, Set(d, static_cast<T>(1.0067153345685411)), c);
1133+
d_val = IfThenElse(mask, Set(d, static_cast<T>(0.00026712584767189665)),
1134+
d_val);
10991135

11001136
// If y < t0 (idx 0)
11011137
mask = Lt(y, t0);
1102-
b = IfThenElse(mask, Set(d, static_cast<T>(-1.00194730895928918)), b);
1103-
c = IfThenElse(mask, Set(d, static_cast<T>(0.58385589069067223)), c);
1104-
d_coef = IfThenElse(mask, Set(d, static_cast<T>(0.437891917978712797)),
1105-
d_coef);
1138+
a = IfThenElse(mask, Set(d, static_cast<T>(0.78766119873962426)), a);
1139+
b = IfThenElse(mask, Set(d, static_cast<T>(-0.29967724727628686)), b);
1140+
c = IfThenElse(mask, Set(d, static_cast<T>(1.0358118335702087)), c);
1141+
d_val = IfThenElse(mask, Set(d, static_cast<T>(0.0023082932745966296)),
1142+
d_val);
11061143
}
11071144
}
1108-
1109-
// Math: y = (x + b)/(cx + d_coef)
1110-
auto num = Add(y, b);
1111-
auto den = MulAdd(c, y, d_coef);
1112-
1113-
auto approx = Div(num, den);
1145+
// Math: approx = (a*z + b)*z^2 + (c*z + d_val)
1146+
const auto z2 = Mul(z, z);
1147+
const auto pab = MulAdd(a, z, b);
1148+
const auto pcd = MulAdd(c, z, d_val);
1149+
approx = MulAdd(pab, z2, pcd);
11141150

11151151
return MulAdd(exp, kLn2, approx);
11161152
}

hwy/contrib/math/math_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
300300
struct TestFastLog {
301301
template <class T, class D>
302302
HWY_NOINLINE void operator()(T, D d) {
303-
const double max_relative_error = 9.5E-5; // SVE: 8.99
303+
const double max_relative_error = 1.15E-5;
304304
const uint64_t samples = 1000000;
305305
if (sizeof(T) == 4) {
306306
TestMathRelative<T, D>("FastLog", std::log, CallFastLog, d,

0 commit comments

Comments
 (0)