@@ -316,13 +316,18 @@ Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
316316 // Positive arguments to exp() have preciser ULP.
317317 // So, we will rewrite the expression to always use exp(2*x)
318318 // instead of exp(-2*x) when we are close to zero.
319+ // Rewriting it like this is slighlty more expensive, hence the branch
320+ // to only pay this extra cost in case we need MULPE-optimized approximations.
319321 Expr flip_exp = abs_x > constant (type, 4 );
320322 Expr arg_exp = select (flip_exp, -abs_x, abs_x);
321323 Expr exp2x = Halide::fast_exp (2 * arg_exp, prec);
322324 Expr tanh = (exp2x - constant (type, 1.0 )) / (exp2x + constant (type, 1 ));
323325 tanh = select (flip_exp ^ flip_sign, -tanh, tanh);
324326 return common_subexpression_elimination (tanh, true );
325327 } else {
328+ // Even if we are optimizing for MAE, the nested call to exp()
329+ // should be MULPE optimized for accuracy, as we are taking ratios.
330+ prec.optimized_for = ApproximationPrecision::MULPE ;
326331 Expr exp2x = Halide::fast_exp (-2 * abs_x, prec);
327332 Expr tanh = (constant (type, 1 ) - exp2x) / (constant (type, 1 ) + exp2x);
328333 tanh = select (flip_sign, -tanh, tanh);
@@ -435,6 +440,57 @@ IntrinsicsInfoPerDeviceAPI ii_tanh{
435440}};
436441// clang-format on
437442
443+ bool fast_math_func_has_intrinsic_based_implementation (Call::IntrinsicOp op, DeviceAPI device, const Target &t) {
444+ const IntrinsicsInfoPerDeviceAPI *iipda = nullptr ;
445+ switch (op) {
446+ case Call::fast_atan:
447+ case Call::fast_atan2:
448+ iipda = &ii_atan_atan2;
449+ break ;
450+ case Call::fast_cos:
451+ iipda = &ii_cos;
452+ break ;
453+ case Call::fast_exp:
454+ iipda = &ii_exp;
455+ break ;
456+ case Call::fast_log:
457+ iipda = &ii_log;
458+ break ;
459+ case Call::fast_pow:
460+ iipda = &ii_pow;
461+ break ;
462+ case Call::fast_sin:
463+ iipda = &ii_sin;
464+ break ;
465+ case Call::fast_tan:
466+ iipda = &ii_tan;
467+ break ;
468+ case Call::fast_tanh:
469+ iipda = &ii_tanh;
470+ break ;
471+
472+ default :
473+ std::string name = Call::get_intrinsic_name (op);
474+ internal_assert (name.length () > 5 && name.substr (0 , 5 ) != " fast_" ) << " Did not handle " << name << " in switch case" ;
475+ break ;
476+ }
477+
478+
479+ internal_assert (iipda != nullptr ) << " Function is only supported for fast_xxx math functions. Got: " << Call::get_intrinsic_name (op);
480+
481+ for (const auto &cand : iipda->device_apis ) {
482+ if (cand.device_api == device) {
483+ if (cand.intrinsic .defined ()) {
484+ if (op == Call::fast_tanh && device == DeviceAPI::CUDA ) {
485+ return t.get_cuda_capability_lower_bound () >= 75 ;
486+ }
487+ return true ;
488+ }
489+ }
490+ }
491+ return false ;
492+ }
493+
438494IntrinsicsInfo resolve_precision (ApproximationPrecision &prec, const IntrinsicsInfoPerDeviceAPI &iida, DeviceAPI api) {
439495 IntrinsicsInfo ii{};
440496 for (const auto &cand : iida.device_apis ) {
@@ -562,6 +618,18 @@ class LowerFastMathFunctions : public IRMutator {
562618 return for_device_api == DeviceAPI::CUDA && target.get_cuda_capability_lower_bound () >= 75 ;
563619 }
564620
621+ void adjust_precision_for_target (ApproximationPrecision &prec) {
622+ if (for_device_api == DeviceAPI::None) {
623+ if (target.arch == Target::Arch::X86 ) {
624+ // If we do not have fused-multiply-add, we lose some precision.
625+ if (target.bits == 32 || !target.has_feature (Target::Feature::FMA )) {
626+ prec.constraint_max_absolute_error *= 0 .5f ;
627+ prec.constraint_max_ulp_error /= 2 ;
628+ }
629+ }
630+ }
631+ }
632+
565633 /* * Strips the fast_ prefix, appends the type suffix, and
566634 * drops the precision argument from the end. */
567635 Expr to_native_func (const Call *op) {
@@ -652,6 +720,7 @@ class LowerFastMathFunctions : public IRMutator {
652720 }
653721
654722 // No known fast version available, we will expand our own approximation.
723+ adjust_precision_for_target (prec);
655724 return ApproxImpl::fast_sin (mutate (op->args [0 ]), prec);
656725 } else if (op->is_intrinsic (Call::fast_cos)) {
657726 ApproximationPrecision prec = extract_approximation_precision (op);
@@ -664,6 +733,7 @@ class LowerFastMathFunctions : public IRMutator {
664733 }
665734
666735 // No known fast version available, we will expand our own approximation.
736+ adjust_precision_for_target (prec);
667737 return ApproxImpl::fast_cos (mutate (op->args [0 ]), prec);
668738 } else if (op->is_intrinsic (Call::fast_atan) || op->is_intrinsic (Call::fast_atan2)) {
669739 // Handle fast_atan and fast_atan2 together!
@@ -673,6 +743,8 @@ class LowerFastMathFunctions : public IRMutator {
673743 // The native atan is fast: fall back to native and continue lowering.
674744 return to_native_func (op);
675745 }
746+
747+ adjust_precision_for_target (prec);
676748 if (op->is_intrinsic (Call::fast_atan)) {
677749 return ApproxImpl::fast_atan (mutate (op->args [0 ]), prec);
678750 } else {
@@ -696,6 +768,8 @@ class LowerFastMathFunctions : public IRMutator {
696768 // The native atan is fast: fall back to native and continue lowering.
697769 return to_native_func (op);
698770 }
771+
772+ adjust_precision_for_target (prec);
699773 return ApproxImpl::fast_tan (mutate (op->args [0 ]), prec);
700774 } else if (op->is_intrinsic (Call::fast_exp)) {
701775 // Handle fast_exp and fast_log together!
@@ -718,6 +792,8 @@ class LowerFastMathFunctions : public IRMutator {
718792 // The native atan is fast: fall back to native and continue lowering.
719793 return to_native_func (op);
720794 }
795+
796+ adjust_precision_for_target (prec);
721797 return ApproxImpl::fast_exp (mutate (op->args [0 ]), prec);
722798 } else if (op->is_intrinsic (Call::fast_log)) {
723799 // Handle fast_exp and fast_log together!
@@ -738,6 +814,8 @@ class LowerFastMathFunctions : public IRMutator {
738814 // The native atan is fast: fall back to native and continue lowering.
739815 return to_native_func (op);
740816 }
817+
818+ adjust_precision_for_target (prec);
741819 return ApproxImpl::fast_log (mutate (op->args [0 ]), prec);
742820 } else if (op->is_intrinsic (Call::fast_tanh)) {
743821 ApproximationPrecision prec = extract_approximation_precision (op);
@@ -748,6 +826,7 @@ class LowerFastMathFunctions : public IRMutator {
748826 }
749827
750828 // Expand using defintion in terms of exp(2x), and recurse.
829+ // Note: no adjustment of precision, as the recursed mutation will take care of that!
751830 return mutate (ApproxImpl::fast_tanh (op->args [0 ], prec));
752831 } else if (op->is_intrinsic (Call::fast_pow)) {
753832 ApproximationPrecision prec = extract_approximation_precision (op);
0 commit comments