Skip to content

Commit 2e45590

Browse files
committed
Make powm1 GPU compatible
1 parent 8fc0e39 commit 2e45590

2 files changed

Lines changed: 43 additions & 29 deletions

File tree

include/boost/math/special_functions/math_fwd.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -591,11 +591,11 @@ namespace boost
591591

592592
// Power - 1
593593
template <class T1, class T2>
594-
tools::promote_args_t<T1, T2>
594+
BOOST_MATH_GPU_ENABLED tools::promote_args_t<T1, T2>
595595
powm1(const T1 a, const T2 z);
596596

597597
template <class T1, class T2, class Policy>
598-
tools::promote_args_t<T1, T2>
598+
BOOST_MATH_GPU_ENABLED tools::promote_args_t<T1, T2>
599599
powm1(const T1 a, const T2 z, const Policy&);
600600

601601
// sqrt(1+x) - 1
@@ -1481,7 +1481,7 @@ namespace boost
14811481
\
14821482
template <class T1, class T2>\
14831483
inline boost::math::tools::promote_args_t<T1, T2> \
1484-
powm1(const T1 a, const T2 z){ return boost::math::powm1(a, z, Policy()); }\
1484+
BOOST_MATH_GPU_ENABLED powm1(const T1 a, const T2 z){ return boost::math::powm1(a, z, Policy()); }\
14851485
\
14861486
template <class T>\
14871487
BOOST_MATH_GPU_ENABLED inline boost::math::tools::promote_args_t<T> sqrt1pm1(const T& val){ return boost::math::sqrt1pm1(val, Policy()); }\

include/boost/math/special_functions/powm1.hpp

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// (C) Copyright John Maddock 2006.
2+
// (C) Copyright Matt Borland 2024.
23
// Use, modification and distribution are subject to the
34
// Boost Software License, Version 1.0. (See accompanying file
45
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
@@ -12,6 +13,7 @@
1213
#pragma warning(disable:4702) // Unreachable code (release mode only warning)
1314
#endif
1415

16+
#include <boost/math/tools/config.hpp>
1517
#include <boost/math/special_functions/math_fwd.hpp>
1618
#include <boost/math/special_functions/log1p.hpp>
1719
#include <boost/math/special_functions/expm1.hpp>
@@ -22,32 +24,23 @@
2224
namespace boost{ namespace math{ namespace detail{
2325

2426
template <class T, class Policy>
25-
inline T powm1_imp(const T x, const T y, const Policy& pol)
27+
BOOST_MATH_GPU_ENABLED inline T powm1_imp(const T x, const T y, const Policy& pol)
2628
{
2729
BOOST_MATH_STD_USING
28-
static const char* function = "boost::math::powm1<%1%>(%1%, %1%)";
29-
if (x > 0)
30+
constexpr auto function = "boost::math::powm1<%1%>(%1%, %1%)";
31+
32+
if ((fabs(y * (x - 1)) < T(0.5)) || (fabs(y) < T(0.2)))
3033
{
31-
if ((fabs(y * (x - 1)) < T(0.5)) || (fabs(y) < T(0.2)))
32-
{
33-
// We don't have any good/quick approximation for log(x) * y
34-
// so just try it and see:
35-
T l = y * log(x);
36-
if (l < T(0.5))
37-
return boost::math::expm1(l, pol);
38-
if (l > boost::math::tools::log_max_value<T>())
39-
return boost::math::policies::raise_overflow_error<T>(function, nullptr, pol);
40-
// fall through....
41-
}
42-
}
43-
else if ((boost::math::signbit)(x)) // Need to error check -0 here as well
44-
{
45-
// y had better be an integer:
46-
if (boost::math::trunc(y) != y)
47-
return boost::math::policies::raise_domain_error<T>(function, "For non-integral exponent, expected base > 0 but got %1%", x, pol);
48-
if (boost::math::trunc(y / 2) == y / 2)
49-
return powm1_imp(T(-x), y, pol);
34+
// We don't have any good/quick approximation for log(x) * y
35+
// so just try it and see:
36+
T l = y * log(x);
37+
if (l < T(0.5))
38+
return boost::math::expm1(l, pol);
39+
if (l > boost::math::tools::log_max_value<T>())
40+
return boost::math::policies::raise_overflow_error<T>(function, nullptr, pol);
41+
// fall through....
5042
}
43+
5144
T result = pow(x, y) - 1;
5245
if((boost::math::isinf)(result))
5346
return result < 0 ? -boost::math::policies::raise_overflow_error<T>(function, nullptr, pol) : boost::math::policies::raise_overflow_error<T>(function, nullptr, pol);
@@ -56,22 +49,43 @@ inline T powm1_imp(const T x, const T y, const Policy& pol)
5649
return result;
5750
}
5851

52+
template <class T, class Policy>
53+
BOOST_MATH_GPU_ENABLED inline T powm1_imp_dispatch(const T x, const T y, const Policy& pol)
54+
{
55+
BOOST_MATH_STD_USING
56+
57+
if ((boost::math::signbit)(x)) // Need to error check -0 here as well
58+
{
59+
constexpr auto function = "boost::math::powm1<%1%>(%1%, %1%)";
60+
61+
// y had better be an integer:
62+
if (boost::math::trunc(y) != y)
63+
return boost::math::policies::raise_domain_error<T>(function, "For non-integral exponent, expected base > 0 but got %1%", x, pol);
64+
if (boost::math::trunc(y / 2) == y / 2)
65+
return powm1_imp(T(-x), T(y), pol);
66+
}
67+
else
68+
{
69+
return powm1_imp(T(x), T(y), pol);
70+
}
71+
}
72+
5973
} // detail
6074

6175
template <class T1, class T2>
62-
inline typename tools::promote_args<T1, T2>::type
76+
BOOST_MATH_GPU_ENABLED inline typename tools::promote_args<T1, T2>::type
6377
powm1(const T1 a, const T2 z)
6478
{
6579
typedef typename tools::promote_args<T1, T2>::type result_type;
66-
return detail::powm1_imp(static_cast<result_type>(a), static_cast<result_type>(z), policies::policy<>());
80+
return detail::powm1_imp_dispatch(static_cast<result_type>(a), static_cast<result_type>(z), policies::policy<>());
6781
}
6882

6983
template <class T1, class T2, class Policy>
70-
inline typename tools::promote_args<T1, T2>::type
84+
BOOST_MATH_GPU_ENABLED inline typename tools::promote_args<T1, T2>::type
7185
powm1(const T1 a, const T2 z, const Policy& pol)
7286
{
7387
typedef typename tools::promote_args<T1, T2>::type result_type;
74-
return detail::powm1_imp(static_cast<result_type>(a), static_cast<result_type>(z), pol);
88+
return detail::powm1_imp_dispatch(static_cast<result_type>(a), static_cast<result_type>(z), pol);
7589
}
7690

7791
} // namespace math

0 commit comments

Comments
 (0)