Skip to content

Commit a8fd03b

Browse files
committed
Use 50% tighter constraints when no FMA is available to compensate for lost precision. Also test accuracy of non-forced polynomials, i.e., potentially intrinsics.
1 parent 39e3a97 commit a8fd03b

8 files changed

Lines changed: 197 additions & 44 deletions

src/ApproximationTables.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ const Approximation *find_best_approximation(const std::vector<Approximation> &t
168168
std::printf("Looking for min_terms=%d, max_absolute_error=%f\n",
169169
precision.constraint_min_poly_terms, precision.constraint_max_absolute_error);
170170
#endif
171-
constexpr double safety_factor = 1.05;
171+
constexpr double safety_factor = 1.02;
172172
for (size_t i = 0; i < table.size(); ++i) {
173173
const Approximation &e = table[i];
174174

src/FastMathFunctions.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,18 @@ Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
316316
// Positive arguments to exp() have preciser ULP.
317317
// So, we will rewrite the expression to always use exp(2*x)
318318
// instead of exp(-2*x) when we are close to zero.
319+
// Rewriting it like this is slighlty more expensive, hence the branch
320+
// to only pay this extra cost in case we need MULPE-optimized approximations.
319321
Expr flip_exp = abs_x > constant(type, 4);
320322
Expr arg_exp = select(flip_exp, -abs_x, abs_x);
321323
Expr exp2x = Halide::fast_exp(2 * arg_exp, prec);
322324
Expr tanh = (exp2x - constant(type, 1.0)) / (exp2x + constant(type, 1));
323325
tanh = select(flip_exp ^ flip_sign, -tanh, tanh);
324326
return common_subexpression_elimination(tanh, true);
325327
} else {
328+
// Even if we are optimizing for MAE, the nested call to exp()
329+
// should be MULPE optimized for accuracy, as we are taking ratios.
330+
prec.optimized_for = ApproximationPrecision::MULPE;
326331
Expr exp2x = Halide::fast_exp(-2 * abs_x, prec);
327332
Expr tanh = (constant(type, 1) - exp2x) / (constant(type, 1) + exp2x);
328333
tanh = select(flip_sign, -tanh, tanh);
@@ -435,6 +440,57 @@ IntrinsicsInfoPerDeviceAPI ii_tanh{
435440
}};
436441
// clang-format on
437442

443+
bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, DeviceAPI device, const Target &t) {
444+
const IntrinsicsInfoPerDeviceAPI *iipda = nullptr;
445+
switch (op) {
446+
case Call::fast_atan:
447+
case Call::fast_atan2:
448+
iipda = &ii_atan_atan2;
449+
break;
450+
case Call::fast_cos:
451+
iipda = &ii_cos;
452+
break;
453+
case Call::fast_exp:
454+
iipda = &ii_exp;
455+
break;
456+
case Call::fast_log:
457+
iipda = &ii_log;
458+
break;
459+
case Call::fast_pow:
460+
iipda = &ii_pow;
461+
break;
462+
case Call::fast_sin:
463+
iipda = &ii_sin;
464+
break;
465+
case Call::fast_tan:
466+
iipda = &ii_tan;
467+
break;
468+
case Call::fast_tanh:
469+
iipda = &ii_tanh;
470+
break;
471+
472+
default:
473+
std::string name = Call::get_intrinsic_name(op);
474+
internal_assert(name.length() > 5 && name.substr(0, 5) != "fast_") << "Did not handle " << name << " in switch case";
475+
break;
476+
}
477+
478+
479+
internal_assert(iipda != nullptr) << "Function is only supported for fast_xxx math functions. Got: " << Call::get_intrinsic_name(op);
480+
481+
for (const auto &cand : iipda->device_apis) {
482+
if (cand.device_api == device) {
483+
if (cand.intrinsic.defined()) {
484+
if (op == Call::fast_tanh && device == DeviceAPI::CUDA) {
485+
return t.get_cuda_capability_lower_bound() >= 75;
486+
}
487+
return true;
488+
}
489+
}
490+
}
491+
return false;
492+
}
493+
438494
IntrinsicsInfo resolve_precision(ApproximationPrecision &prec, const IntrinsicsInfoPerDeviceAPI &iida, DeviceAPI api) {
439495
IntrinsicsInfo ii{};
440496
for (const auto &cand : iida.device_apis) {
@@ -562,6 +618,18 @@ class LowerFastMathFunctions : public IRMutator {
562618
return for_device_api == DeviceAPI::CUDA && target.get_cuda_capability_lower_bound() >= 75;
563619
}
564620

621+
void adjust_precision_for_target(ApproximationPrecision &prec) {
622+
if (for_device_api == DeviceAPI::None) {
623+
if (target.arch == Target::Arch::X86) {
624+
// If we do not have fused-multiply-add, we lose some precision.
625+
if (target.bits == 32 || !target.has_feature(Target::Feature::FMA)) {
626+
prec.constraint_max_absolute_error *= 0.5f;
627+
prec.constraint_max_ulp_error /= 2;
628+
}
629+
}
630+
}
631+
}
632+
565633
/** Strips the fast_ prefix, appends the type suffix, and
566634
* drops the precision argument from the end. */
567635
Expr to_native_func(const Call *op) {
@@ -652,6 +720,7 @@ class LowerFastMathFunctions : public IRMutator {
652720
}
653721

654722
// No known fast version available, we will expand our own approximation.
723+
adjust_precision_for_target(prec);
655724
return ApproxImpl::fast_sin(mutate(op->args[0]), prec);
656725
} else if (op->is_intrinsic(Call::fast_cos)) {
657726
ApproximationPrecision prec = extract_approximation_precision(op);
@@ -664,6 +733,7 @@ class LowerFastMathFunctions : public IRMutator {
664733
}
665734

666735
// No known fast version available, we will expand our own approximation.
736+
adjust_precision_for_target(prec);
667737
return ApproxImpl::fast_cos(mutate(op->args[0]), prec);
668738
} else if (op->is_intrinsic(Call::fast_atan) || op->is_intrinsic(Call::fast_atan2)) {
669739
// Handle fast_atan and fast_atan2 together!
@@ -673,6 +743,8 @@ class LowerFastMathFunctions : public IRMutator {
673743
// The native atan is fast: fall back to native and continue lowering.
674744
return to_native_func(op);
675745
}
746+
747+
adjust_precision_for_target(prec);
676748
if (op->is_intrinsic(Call::fast_atan)) {
677749
return ApproxImpl::fast_atan(mutate(op->args[0]), prec);
678750
} else {
@@ -696,6 +768,8 @@ class LowerFastMathFunctions : public IRMutator {
696768
// The native atan is fast: fall back to native and continue lowering.
697769
return to_native_func(op);
698770
}
771+
772+
adjust_precision_for_target(prec);
699773
return ApproxImpl::fast_tan(mutate(op->args[0]), prec);
700774
} else if (op->is_intrinsic(Call::fast_exp)) {
701775
// Handle fast_exp and fast_log together!
@@ -718,6 +792,8 @@ class LowerFastMathFunctions : public IRMutator {
718792
// The native atan is fast: fall back to native and continue lowering.
719793
return to_native_func(op);
720794
}
795+
796+
adjust_precision_for_target(prec);
721797
return ApproxImpl::fast_exp(mutate(op->args[0]), prec);
722798
} else if (op->is_intrinsic(Call::fast_log)) {
723799
// Handle fast_exp and fast_log together!
@@ -738,6 +814,8 @@ class LowerFastMathFunctions : public IRMutator {
738814
// The native atan is fast: fall back to native and continue lowering.
739815
return to_native_func(op);
740816
}
817+
818+
adjust_precision_for_target(prec);
741819
return ApproxImpl::fast_log(mutate(op->args[0]), prec);
742820
} else if (op->is_intrinsic(Call::fast_tanh)) {
743821
ApproximationPrecision prec = extract_approximation_precision(op);
@@ -748,6 +826,7 @@ class LowerFastMathFunctions : public IRMutator {
748826
}
749827

750828
// Expand using defintion in terms of exp(2x), and recurse.
829+
// Note: no adjustment of precision, as the recursed mutation will take care of that!
751830
return mutate(ApproxImpl::fast_tanh(op->args[0], prec));
752831
} else if (op->is_intrinsic(Call::fast_pow)) {
753832
ApproximationPrecision prec = extract_approximation_precision(op);

src/FastMathFunctions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
#define HALIDE_INTERNAL_FAST_MATH_H
33

44
#include "Expr.h"
5+
#include "IR.h"
56

67
namespace Halide {
78
namespace Internal {
89

10+
bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, DeviceAPI device, const Target &t);
11+
912
Stmt lower_fast_math_functions(const Stmt &s, const Target &t);
1013

1114
}

src/IROperator.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ Expr fast_pow(const Expr &x, const Expr &y, ApproximationPrecision prec) {
13831383
if (auto i = as_const_int(y)) {
13841384
return raise_to_integer_power(x, *i);
13851385
}
1386+
user_assert(x.type() == Float(32) && y.type() == Float(32)) << "fast_exp only works for Float(32)";
13861387
return Call::make(x.type(), Call::fast_pow, {x, y, make_approximation_precision_info(prec)}, Call::PureIntrinsic);
13871388
}
13881389

src/IROperator.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,9 +1073,11 @@ struct ApproximationPrecision {
10731073
* See \ref ApproximationPrecision for details on specifying precision.
10741074
*/
10751075
// @{
1076-
//* On NVIDIA CUDA: default-precision maps to a dedicated sin.approx.f32 instruction. */
1076+
/** Caution: Might exceed the range (-1, 1) by a tiny bit.
1077+
* On NVIDIA CUDA: default-precision maps to a dedicated sin.approx.f32 instruction. */
10771078
Expr fast_sin(const Expr &x, ApproximationPrecision precision = {});
1078-
/** On NVIDIA CUDA: default-precision maps to a dedicated cos.approx.f32 instruction. */
1079+
/** Caution: Might exceed the range (-1, 1) by a tiny bit.
1080+
* On NVIDIA CUDA: default-precision maps to a dedicated cos.approx.f32 instruction. */
10791081
Expr fast_cos(const Expr &x, ApproximationPrecision precision = {});
10801082
/** On NVIDIA CUDA: default-precision maps to a combination of sin.approx.f32,
10811083
* cos.approx.f32, div.approx.f32 instructions. */
@@ -1118,6 +1120,7 @@ Expr fast_pow(const Expr &x, const Expr &y, ApproximationPrecision precision = {
11181120

11191121
/** Fast approximate pow for Float(32).
11201122
* Approximations accurate to 2e-7 MAE, and Max 2500 ULPs (on average < 1 ULP) available.
1123+
* Caution: might exceed the range (-1, 1) by a tiny bit.
11211124
* Vectorizes cleanly when using polynomials.
11221125
* Slow on x86 if you don't have at least sse 4.1.
11231126
* On NVIDIA CUDA: default-precision maps to a combination of ex2.approx.f32 and lg2.approx.f32.

0 commit comments

Comments
 (0)