Skip to content

Commit 4dedb4f

Browse files
committed
fixes for stan to work with opencl v2025.07.22
1 parent 96c62fe commit 4dedb4f

11 files changed

Lines changed: 411 additions & 214 deletions

File tree

stan/math/opencl/kernel_generator/elt_function_cl.hpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <stan/math/opencl/kernels/device_functions/multiply_log.hpp>
2525
#include <stan/math/opencl/kernels/device_functions/Phi.hpp>
2626
#include <stan/math/opencl/kernels/device_functions/Phi_approx.hpp>
27+
#include <stan/math/opencl/kernels/device_functions/std_normal_lcdf.hpp>
2728
#include <stan/math/opencl/kernels/device_functions/trigamma.hpp>
2829
#include <stan/math/opencl/matrix_cl_view.hpp>
2930
#include <stan/math/opencl/kernel_generator/common_return_scalar.hpp>
@@ -314,6 +315,12 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function)
314315
ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx,
315316
opencl_kernels::inv_logit_device_function,
316317
opencl_kernels::phi_approx_device_function)
318+
ADD_UNARY_FUNCTION_WITH_INCLUDES(
319+
std_normal_lcdf_scaled_impl,
320+
opencl_kernels::std_normal_lcdf_device_function)
321+
ADD_UNARY_FUNCTION_WITH_INCLUDES(
322+
std_normal_lcdf_dscaled_impl,
323+
opencl_kernels::std_normal_lcdf_device_function)
317324
ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_Phi, opencl_kernels::log1m_device_function,
318325
opencl_kernels::phi_device_function,
319326
opencl_kernels::inv_phi_device_function)
@@ -352,10 +359,53 @@ ADD_BINARY_FUNCTION_WITH_INCLUDES(
352359
stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
353360
stan::math::opencl_kernels::lbeta_device_function,
354361
stan::math::opencl_kernels::binomial_coefficient_log_device_function)
355-
ADD_BINARY_FUNCTION_WITH_INCLUDES(
356-
lbeta, stan::math::opencl_kernels::lgamma_stirling_device_function,
362+
template <typename T1, typename T2>
363+
class lbeta_ : public elt_function_cl<lbeta_<T1, T2>, double, T1, T2> {
364+
using base = elt_function_cl<lbeta_<T1, T2>, double, T1, T2>;
365+
using base::arguments_;
366+
367+
public:
368+
using base::rows;
369+
using base::cols;
370+
static const std::vector<const char*> includes;
371+
explicit lbeta_(T1&& a, T2&& b)
372+
: base("stan_lbeta", std::forward<T1>(a), std::forward<T2>(b)) {
373+
if (a.rows() != base::dynamic && b.rows() != base::dynamic) {
374+
check_size_match("lbeta", "Rows of ", "a", a.rows(), "rows of ", "b",
375+
b.rows());
376+
}
377+
if (a.cols() != base::dynamic && b.cols() != base::dynamic) {
378+
check_size_match("lbeta", "Columns of ", "a", a.cols(), "columns of ",
379+
"b", b.cols());
380+
}
381+
}
382+
inline auto deep_copy() const {
383+
auto&& arg1_copy = this->template get_arg<0>().deep_copy();
384+
auto&& arg2_copy = this->template get_arg<1>().deep_copy();
385+
return lbeta_<std::remove_reference_t<decltype(arg1_copy)>,
386+
std::remove_reference_t<decltype(arg2_copy)>>{
387+
std::move(arg1_copy), std::move(arg2_copy)};
388+
}
389+
inline std::pair<int, int> extreme_diagonals() const {
390+
return {-rows() + 1, cols() - 1};
391+
}
392+
};
393+
394+
template <typename T1, typename T2,
395+
require_all_kernel_expressions_t<T1, T2>* = nullptr,
396+
require_any_not_stan_scalar_t<T1, T2>* = nullptr>
397+
inline lbeta_<as_operation_cl_t<T1>, as_operation_cl_t<T2>> lbeta(T1&& a,
398+
T2&& b) {
399+
return lbeta_<as_operation_cl_t<T1>, as_operation_cl_t<T2>>(
400+
as_operation_cl(std::forward<T1>(a)),
401+
as_operation_cl(std::forward<T2>(b)));
402+
}
403+
404+
template <typename T1, typename T2>
405+
const std::vector<const char*> lbeta_<T1, T2>::includes{
406+
stan::math::opencl_kernels::lgamma_stirling_device_function,
357407
stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
358-
stan::math::opencl_kernels::lbeta_device_function)
408+
stan::math::opencl_kernels::lbeta_device_function};
359409
ADD_BINARY_FUNCTION_WITH_INCLUDES(
360410
log_inv_logit_diff, opencl_kernels::log1p_exp_device_function,
361411
opencl_kernels::log1m_exp_device_function,

stan/math/opencl/kernels/device_functions/binomial_coefficient_log.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ static constexpr const char* binomial_coefficient_log_device_function
9595
} else if (n_plus_1 < LGAMMA_STIRLING_DIFF_USEFUL) {
9696
return lgamma(n_plus_1) - lgamma(k + 1) - lgamma(n_plus_1_mk);
9797
} else {
98-
return -lbeta(n_plus_1_mk, k + 1) - log1p(n);
98+
return -stan_lbeta(n_plus_1_mk, k + 1) - log1p(n);
9999
}
100100
}
101101
// \cond

stan/math/opencl/kernels/device_functions/lbeta.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ static constexpr const char* lbeta_device_function
5959
* @param b Second value
6060
* @return Log of the beta function applied to the two values.
6161
*/
62-
double lbeta(double a, double b) {
62+
double stan_lbeta(double a, double b) {
6363
if (isnan(a) || isnan(b)) {
64-
return a;
64+
return NAN;
6565
}
6666

6767
double x; // x is the smaller of the two

stan/math/opencl/kernels/device_functions/std_normal_lcdf.hpp

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_STD_NORMAL_LCDF_HPP
2+
#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_STD_NORMAL_LCDF_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/stringify.hpp>
6+
#include <string>
7+
8+
namespace stan {
9+
namespace math {
10+
namespace opencl_kernels {
11+
// \cond
12+
static constexpr const char* std_normal_lcdf_device_function
13+
= "\n"
14+
"#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_STD_NORMAL_LCDF\n"
15+
"#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_STD_NORMAL_LCDF\n"
16+
STRINGIFY(
17+
/** \ingroup opencl_kernels
18+
* Return the log standard normal cumulative distribution function
19+
* evaluated from the scaled input `x / sqrt(2)`.
20+
*
21+
* @param scaled_y input scaled by `1 / sqrt(2)`
22+
* @return log(Phi(x))
23+
*/
24+
inline double std_normal_lcdf_scaled_impl(double scaled_y) {
25+
double lcdf_n;
26+
if (scaled_y > 0.0) {
27+
// CDF(x) = 1/2 + 1/2 erf(x) = 1 - 1/2 erfc(x)
28+
lcdf_n = log1p(-0.5 * erfc(scaled_y));
29+
if (isnan(lcdf_n)) {
30+
lcdf_n = 0;
31+
}
32+
} else if (scaled_y > -20.0) {
33+
// CDF(x) = 1/2 - 1/2 erf(-x) = 1/2 erfc(-x)
34+
lcdf_n = log(erfc(-scaled_y)) - M_LN2;
35+
} else if (10.0 * log(fabs(scaled_y)) < log(DBL_MAX)) {
36+
// Need direct approximation once erfc(-x) underflows.
37+
const double x2 = scaled_y * scaled_y;
38+
const double x4 = pow(scaled_y, 4);
39+
const double x6 = pow(scaled_y, 6);
40+
const double x8 = pow(scaled_y, 8);
41+
const double x10 = pow(scaled_y, 10);
42+
const double temp_p
43+
= 0.000658749161529837803157 + 0.0160837851487422766278 / x2
44+
+ 0.125781726111229246204 / x4
45+
+ 0.360344899949804439429 / x6
46+
+ 0.305326634961232344035 / x8
47+
+ 0.0163153871373020978498 / x10;
48+
const double temp_q
49+
= -0.00233520497626869185443
50+
- 0.0605183413124413191178 / x2
51+
- 0.527905102951428412248 / x4
52+
- 1.87295284992346047209 / x6
53+
- 2.56852019228982242072 / x8 - 1.0 / x10;
54+
lcdf_n = log(0.5 * M_2_SQRTPI + (temp_p / temp_q) / x2)
55+
- M_LN2 - log(-scaled_y) - x2;
56+
} else {
57+
lcdf_n = -INFINITY;
58+
}
59+
return lcdf_n;
60+
}
61+
62+
/** \ingroup opencl_kernels
63+
* Return the derivative of log standard normal cumulative
64+
* distribution function with respect to the scaled input
65+
* `x / sqrt(2)`.
66+
*
67+
* @param scaled_y input scaled by `1 / sqrt(2)`
68+
* @return d / d(scaled_y) log(Phi(x))
69+
*/
70+
inline double std_normal_lcdf_dscaled_impl(double scaled_y) {
71+
double dnlcdf = 0.0;
72+
double t = 0.0;
73+
double t2 = 0.0;
74+
double t4 = 0.0;
75+
const double x2 = scaled_y * scaled_y;
76+
77+
if (scaled_y > 2.9) {
78+
t = 1.0 / (1.0 + 0.3275911 * scaled_y);
79+
t2 = t * t;
80+
t4 = pow(t, 4);
81+
dnlcdf
82+
= 0.5 * M_2_SQRTPI
83+
/ (exp(x2) - 0.254829592 + 0.284496736 * t
84+
- 1.421413741 * t2 + 1.453152027 * t2 * t
85+
- 1.061405429 * t4);
86+
} else if (scaled_y > 2.5) {
87+
t = scaled_y - 2.7;
88+
t2 = t * t;
89+
t4 = pow(t, 4);
90+
dnlcdf = 0.0003849882382 - 0.002079084702 * t
91+
+ 0.005229340880 * t2 - 0.008029540137 * t2 * t
92+
+ 0.008232190507 * t4 - 0.005692364250 * t4 * t
93+
+ 0.002399496363 * pow(t, 6);
94+
} else if (scaled_y > 2.1) {
95+
t = scaled_y - 2.3;
96+
t2 = t * t;
97+
t4 = pow(t, 4);
98+
dnlcdf = 0.002846135439 - 0.01310032351 * t
99+
+ 0.02732189391 * t2 - 0.03326906904 * t2 * t
100+
+ 0.02482478940 * t4 - 0.009883071924 * t4 * t
101+
- 0.0002771362254 * pow(t, 6);
102+
} else if (scaled_y > 1.5) {
103+
t = scaled_y - 1.85;
104+
t2 = t * t;
105+
t4 = pow(t, 4);
106+
dnlcdf = 0.01849212058 - 0.06876280470 * t
107+
+ 0.1099906382 * t2 - 0.09274533184 * t2 * t
108+
+ 0.03543327418 * t4 + 0.005644855518 * t4 * t
109+
- 0.01111434424 * pow(t, 6);
110+
} else if (scaled_y > 0.8) {
111+
t = scaled_y - 1.15;
112+
t2 = t * t;
113+
t4 = pow(t, 4);
114+
dnlcdf = 0.1585747034 - 0.3898677543 * t
115+
+ 0.3515963775 * t2 - 0.09748053605 * t2 * t
116+
- 0.04347986191 * t4 + 0.02182506378 * t4 * t
117+
+ 0.01074751427 * pow(t, 6);
118+
} else if (scaled_y > 0.1) {
119+
t = scaled_y - 0.45;
120+
t2 = t * t;
121+
t4 = pow(t, 4);
122+
dnlcdf = 0.6245634904 - 0.9521866949 * t
123+
+ 0.3986215682 * t2 + 0.04700850676 * t2 * t
124+
- 0.03478651979 * t4 - 0.01772675404 * t4 * t
125+
+ 0.0006577254811 * pow(t, 6);
126+
} else if (10.0 * log(fabs(scaled_y)) < log(DBL_MAX)) {
127+
t = 1.0 / (1.0 - 0.3275911 * scaled_y);
128+
t2 = t * t;
129+
t4 = pow(t, 4);
130+
dnlcdf
131+
= M_2_SQRTPI
132+
/ (0.254829592 * t - 0.284496736 * t2
133+
+ 1.421413741 * t2 * t - 1.453152027 * t4
134+
+ 1.061405429 * t4 * t);
135+
if (scaled_y < -29.0) {
136+
dnlcdf += 0.0015065154280332 * x2 - 0.3993154819705530 * scaled_y
137+
- 4.2919418242931700;
138+
} else if (scaled_y < -17.0) {
139+
dnlcdf += 0.0001263257217272 * x2 * scaled_y
140+
+ 0.0123586859488623 * x2
141+
- 0.0860505264736028 * scaled_y - 1.252783383752970;
142+
} else if (scaled_y < -7.0) {
143+
dnlcdf += 0.000471585349920831 * x2 * scaled_y
144+
+ 0.0296839305424034 * x2
145+
+ 0.207402143352332 * scaled_y + 0.425316974683324;
146+
} else if (scaled_y < -3.9) {
147+
dnlcdf += -0.0006972280656443 * x2 * scaled_y
148+
+ 0.0068218494628567 * x2
149+
+ 0.0585761964460277 * scaled_y + 0.1034397670201370;
150+
} else if (scaled_y < -2.1) {
151+
dnlcdf += -0.0018742199480885 * x2 * scaled_y
152+
- 0.0097119598291202 * x2
153+
- 0.0170137970924080 * scaled_y - 0.0100428567412041;
154+
}
155+
} else {
156+
dnlcdf = INFINITY;
157+
}
158+
159+
return dnlcdf;
160+
}) "\n#endif\n"; // NOLINT
161+
// \endcond
162+
163+
} // namespace opencl_kernels
164+
} // namespace math
165+
} // namespace stan
166+
167+
#endif
168+
#endif

stan/math/opencl/prim/exp_mod_normal_lcdf.hpp

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <stan/math/prim/fun/elt_divide.hpp>
99
#include <stan/math/prim/fun/elt_multiply.hpp>
1010
#include <stan/math/opencl/kernel_generator.hpp>
11+
#include <stan/math/opencl/prim/std_normal_lcdf.hpp>
1112
#include <stan/math/prim/functor/partials_propagator.hpp>
1213

1314
namespace stan {
@@ -80,30 +81,76 @@ exp_mod_normal_lcdf(const T_y_cl& y, const T_loc_cl& mu,
8081
auto scaled_diff = elt_multiply(diff * INV_SQRT_TWO, sigma_inv);
8182
auto v = elt_multiply(lambda_val, sigma_val);
8283
auto scaled_diff_diff = scaled_diff - v * INV_SQRT_TWO;
83-
auto erf_calc = 0.5 * (1.0 + erf(scaled_diff_diff));
84-
auto exp_term = exp(0.5 * square(v) - elt_multiply(lambda_val, diff));
85-
auto cdf_n = 0.5 + 0.5 * erf(scaled_diff) - elt_multiply(exp_term, erf_calc);
86-
auto cdf_log_expr = colwise_sum(log(cdf_n));
84+
auto cdf_term_1 = 0.5 + 0.5 * erf(scaled_diff);
85+
auto cdf_term_2_phi = 0.5 * (1.0 + erf(scaled_diff_diff));
86+
auto log_exp_term = 0.5 * square(v) - elt_multiply(lambda_val, diff);
87+
auto exp_term = exp(log_exp_term);
88+
auto cdf_term_2 = elt_multiply(exp_term, cdf_term_2_phi);
89+
auto cdf_n = cdf_term_1 - cdf_term_2;
90+
auto use_stable = cdf_n <= 0.0 || !isfinite(cdf_n);
8791

8892
auto exp_term_2 = exp(-square(scaled_diff_diff));
89-
auto deriv_1 = elt_multiply(elt_multiply(lambda_val, exp_term), erf_calc);
93+
auto deriv_1 = elt_multiply(elt_multiply(lambda_val, exp_term), cdf_term_2_phi);
9094
auto deriv_2 = INV_SQRT_TWO_PI
9195
* elt_multiply(elt_multiply(exp_term, exp_term_2), sigma_inv);
9296
auto deriv_3
9397
= INV_SQRT_TWO_PI * elt_multiply(exp(-square(scaled_diff)), sigma_inv);
94-
auto y_deriv = elt_divide(deriv_1 - deriv_2 + deriv_3, cdf_n);
95-
auto mu_deriv = -y_deriv;
96-
auto sigma_deriv = -elt_divide(
98+
auto direct_cdf_log = log(cdf_n);
99+
auto direct_y_deriv = elt_divide(deriv_1 - deriv_2 + deriv_3, cdf_n);
100+
auto direct_mu_deriv = -direct_y_deriv;
101+
auto direct_sigma_deriv = -elt_divide(
97102
elt_multiply(deriv_1 - deriv_2, v)
98103
+ elt_multiply(deriv_3 - deriv_2, scaled_diff) * SQRT_TWO,
99104
cdf_n);
100-
auto lambda_deriv = elt_divide(
105+
auto direct_lambda_deriv = elt_divide(
101106
elt_multiply(
102107
exp_term,
103108
INV_SQRT_TWO_PI * elt_multiply(sigma_val, exp_term_2)
104-
- elt_multiply(elt_multiply(v, sigma_val) - diff, erf_calc)),
109+
- elt_multiply(elt_multiply(v, sigma_val) - diff, cdf_term_2_phi)),
105110
cdf_n);
106111

112+
auto log_cdf_term_1 = std_normal_lcdf_scaled_impl(scaled_diff);
113+
auto dlog_cdf_term_1 = std_normal_lcdf_dscaled_impl(scaled_diff);
114+
auto log_cdf_term_2_phi = std_normal_lcdf_scaled_impl(scaled_diff_diff);
115+
auto dlog_cdf_term_2_phi = std_normal_lcdf_dscaled_impl(scaled_diff_diff);
116+
auto log_cdf_term_2 = log_exp_term + log_cdf_term_2_phi;
117+
auto log_cdf_n = log_diff_exp(log_cdf_term_1, log_cdf_term_2);
118+
auto cdf_term_1_weight = exp(log_cdf_term_1 - log_cdf_n);
119+
auto cdf_term_2_weight = exp(log_cdf_term_2 - log_cdf_n);
120+
auto scaled_diff_deriv
121+
= elt_multiply(dlog_cdf_term_1, sigma_inv * INV_SQRT_TWO);
122+
auto scaled_diff_diff_deriv
123+
= elt_multiply(dlog_cdf_term_2_phi, sigma_inv * INV_SQRT_TWO);
124+
auto stable_y_deriv = elt_multiply(cdf_term_1_weight, scaled_diff_deriv)
125+
- elt_multiply(cdf_term_2_weight,
126+
-lambda_val + scaled_diff_diff_deriv);
127+
auto stable_mu_deriv = -stable_y_deriv;
128+
auto stable_sigma_deriv = elt_multiply(
129+
cdf_term_1_weight,
130+
-elt_multiply(dlog_cdf_term_1,
131+
elt_multiply(scaled_diff,
132+
sigma_inv)))
133+
- elt_multiply(
134+
cdf_term_2_weight,
135+
elt_multiply(lambda_val, v)
136+
- elt_multiply(
137+
dlog_cdf_term_2_phi,
138+
elt_multiply(
139+
scaled_diff + v * INV_SQRT_TWO,
140+
sigma_inv)));
141+
auto stable_lambda_deriv
142+
= -elt_multiply(cdf_term_2_weight,
143+
elt_multiply(v, sigma_val) - diff
144+
- elt_multiply(dlog_cdf_term_2_phi,
145+
sigma_val * INV_SQRT_TWO));
146+
auto cdf_log_expr = colwise_sum(select(use_stable, log_cdf_n, direct_cdf_log));
147+
auto y_deriv = select(use_stable, stable_y_deriv, direct_y_deriv);
148+
auto mu_deriv = select(use_stable, stable_mu_deriv, direct_mu_deriv);
149+
auto sigma_deriv
150+
= select(use_stable, stable_sigma_deriv, direct_sigma_deriv);
151+
auto lambda_deriv
152+
= select(use_stable, stable_lambda_deriv, direct_lambda_deriv);
153+
107154
matrix_cl<char> any_y_neg_inf_cl;
108155
matrix_cl<char> any_y_pos_inf_cl;
109156
matrix_cl<double> cdf_log_cl;

stan/math/opencl/prim/gumbel_cdf.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ inline return_type_t<T_y_cl, T_loc_cl, T_scale_cl> gumbel_cdf(
6767
auto exp_m_scaled_diff = exp(-scaled_diff);
6868
auto cdf_n = exp(-exp_m_scaled_diff);
6969
auto cdf_expr = colwise_prod(cdf_n);
70-
auto rep_deriv = elt_divide(exp(-scaled_diff - exp_m_scaled_diff),
71-
elt_multiply(beta_val, cdf_n));
70+
auto rep_deriv = elt_divide(exp_m_scaled_diff, beta_val);
7271

7372
matrix_cl<double> cdf_cl;
7473
matrix_cl<double> y_deriv_cl;

0 commit comments

Comments
 (0)