|
22 | 22 | #include <stan/math/prim/fun/log_gamma_q_dgamma.hpp> |
23 | 23 | #include <stan/math/prim/functor/partials_propagator.hpp> |
24 | 24 | #include <cmath> |
| 25 | +#include <optional> |
25 | 26 |
|
26 | 27 | namespace stan { |
27 | 28 | namespace math { |
28 | 29 | namespace internal { |
29 | | -template <typename T> |
30 | | -struct Q_eval { |
31 | | - T log_Q{0.0}; |
32 | | - T dlogQ_dalpha{0.0}; |
33 | | - bool ok{false}; |
34 | | -}; |
35 | 30 |
|
36 | 31 | /** |
37 | 32 | * Computes log q and d(log q) / d(alpha) using continued fraction. |
38 | 33 | */ |
39 | | -template <typename T, typename T_shape, bool any_fvar, bool partials_fvar> |
40 | | -static inline Q_eval<T> eval_q_cf(const T& alpha, const T& beta_y) { |
41 | | - Q_eval<T> out; |
| 34 | +template <bool any_fvar, bool partials_fvar, typename T_shape, typename T1, typename T2> |
| 35 | +inline std::optional<std::pair<return_type_t<T1, T2>, return_type_t<T1, T2>>> |
| 36 | +eval_q_cf(const T1& alpha, const T2& beta_y) { |
| 37 | + using scalar_t = return_type_t<T1, T2>; |
| 38 | + using ret_t = std::pair<scalar_t, scalar_t>; |
42 | 39 | if constexpr (!any_fvar && is_autodiff_v<T_shape>) { |
43 | | - auto log_q_result |
| 40 | + std::pair<double, double> log_q_result |
44 | 41 | = log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y)); |
45 | | - out.log_Q = log_q_result.log_q; |
46 | | - out.dlogQ_dalpha = log_q_result.dlog_q_da; |
| 42 | + if (likely(std::isfinite(value_of_rec(log_q_result.first)))) { |
| 43 | + return std::optional{log_q_result}; |
| 44 | + } else { |
| 45 | + return std::optional<ret_t>{std::nullopt}; |
| 46 | + } |
47 | 47 | } else { |
48 | | - out.log_Q = internal::log_q_gamma_cf(alpha, beta_y); |
| 48 | + ret_t out{internal::log_q_gamma_cf(alpha, beta_y), 0.0}; |
| 49 | + if (unlikely(!std::isfinite(value_of_rec(out.first)))) { |
| 50 | + return std::optional<ret_t>{std::nullopt}; |
| 51 | + } |
49 | 52 | if constexpr (is_autodiff_v<T_shape>) { |
50 | 53 | if constexpr (!partials_fvar) { |
51 | | - out.dlogQ_dalpha |
| 54 | + out.second |
52 | 55 | = grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha), digamma(alpha)) |
53 | | - / exp(out.log_Q); |
| 56 | + / exp(out.first); |
54 | 57 | } else { |
55 | | - T alpha_unit = alpha; |
| 58 | + auto alpha_unit = alpha; |
56 | 59 | alpha_unit.d_ = 1; |
57 | | - T beta_y_unit = beta_y; |
| 60 | + auto beta_y_unit = beta_y; |
58 | 61 | beta_y_unit.d_ = 0; |
59 | | - T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit); |
60 | | - out.dlogQ_dalpha = log_Q_fvar.d_; |
| 62 | + auto log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit); |
| 63 | + out.second = log_Q_fvar.d_; |
61 | 64 | } |
62 | 65 | } |
| 66 | + return std::optional{out}; |
63 | 67 | } |
64 | | - |
65 | | - out.ok = std::isfinite(value_of_rec(out.log_Q)); |
66 | | - return out; |
67 | 68 | } |
68 | 69 |
|
69 | 70 | /** |
70 | 71 | * Computes log q and d(log q) / d(alpha) using log1m. |
71 | 72 | */ |
72 | | -template <typename T, typename T_shape, bool partials_fvar> |
73 | | -static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) { |
74 | | - Q_eval<T> out; |
75 | | - out.log_Q = log1m(gamma_p(alpha, beta_y)); |
76 | | - |
77 | | - if (!std::isfinite(value_of_rec(out.log_Q))) { |
78 | | - out.ok = false; |
79 | | - return out; |
| 73 | +template <bool partials_fvar, typename T_shape, typename T1, typename T2> |
| 74 | +inline std::optional<std::pair<return_type_t<T1, T2>, return_type_t<T1, T2>>> |
| 75 | +eval_q_log1m(const T1& alpha, const T2& beta_y) { |
| 76 | + using scalar_t = return_type_t<T1, T2>; |
| 77 | + using ret_t = std::pair<scalar_t, scalar_t>; |
| 78 | + ret_t out{log1m(gamma_p(alpha, beta_y)), 0.0}; |
| 79 | + if (unlikely(!std::isfinite(value_of_rec(out.first)))) { |
| 80 | + return std::optional<ret_t>{std::nullopt}; |
80 | 81 | } |
81 | | - |
82 | 82 | if constexpr (is_autodiff_v<T_shape>) { |
83 | 83 | if constexpr (partials_fvar) { |
84 | | - T alpha_unit = alpha; |
| 84 | + auto alpha_unit = alpha; |
85 | 85 | alpha_unit.d_ = 1; |
86 | | - T beta_unit = beta_y; |
| 86 | + auto beta_unit = beta_y; |
87 | 87 | beta_unit.d_ = 0; |
88 | | - T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit)); |
89 | | - out.dlogQ_dalpha = log_Q_fvar.d_; |
| 88 | + auto log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit)); |
| 89 | + out.second = log_Q_fvar.d_; |
90 | 90 | } else { |
91 | | - out.dlogQ_dalpha |
92 | | - = -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q); |
| 91 | + out.second |
| 92 | + = -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.first); |
93 | 93 | } |
94 | 94 | } |
95 | | - |
96 | | - out.ok = true; |
97 | | - return out; |
| 95 | + return std::optional{out}; |
98 | 96 | } |
99 | 97 | } // namespace internal |
100 | 98 |
|
@@ -137,63 +135,56 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf( |
137 | 135 | for (size_t n = 0; n < N; n++) { |
138 | 136 | // Explicit results for extreme values |
139 | 137 | // The gradients are technically ill-defined, but treated as zero |
140 | | - const T_partials_return y_dbl = y_vec.val(n); |
141 | | - if (y_dbl == 0.0) { |
| 138 | + const T_partials_return y_val = y_vec.val(n); |
| 139 | + if (y_val == 0.0) { |
142 | 140 | continue; |
143 | 141 | } |
144 | | - if (y_dbl == INFTY) { |
| 142 | + if (y_val == INFTY) { |
145 | 143 | return ops_partials.build(negative_infinity()); |
146 | 144 | } |
147 | 145 |
|
148 | | - const T_partials_return alpha_dbl = alpha_vec.val(n); |
149 | | - const T_partials_return beta_dbl = beta_vec.val(n); |
| 146 | + const T_partials_return alpha_val = alpha_vec.val(n); |
| 147 | + const T_partials_return beta_val = beta_vec.val(n); |
150 | 148 |
|
151 | | - const T_partials_return beta_y = beta_dbl * y_dbl; |
| 149 | + const T_partials_return beta_y = beta_val * y_val; |
152 | 150 | if (beta_y == INFTY) { |
153 | 151 | return ops_partials.build(negative_infinity()); |
154 | 152 | } |
155 | | - |
156 | | - const bool use_continued_fraction = beta_y > alpha_dbl + 1.0; |
157 | | - internal::Q_eval<T_partials_return> result; |
158 | | - if (use_continued_fraction) { |
159 | | - result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar, |
160 | | - partials_fvar>(alpha_dbl, beta_y); |
| 153 | + std::optional<std::pair<T_partials_return, T_partials_return>> result; |
| 154 | + if (beta_y > alpha_val + 1.0) { |
| 155 | + result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y); |
161 | 156 | } else { |
162 | | - result |
163 | | - = internal::eval_q_log1m<T_partials_return, T_shape, partials_fvar>( |
164 | | - alpha_dbl, beta_y); |
165 | | - |
166 | | - if (!result.ok && beta_y > 0.0) { |
| 157 | + result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_val, beta_y); |
| 158 | + if (!result && beta_y > 0.0) { |
167 | 159 | // Fallback to continued fraction if log1m fails |
168 | | - result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar, |
169 | | - partials_fvar>(alpha_dbl, beta_y); |
| 160 | + result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y); |
170 | 161 | } |
171 | 162 | } |
172 | | - if (!result.ok) { |
| 163 | + if (unlikely(!result)) { |
173 | 164 | return ops_partials.build(negative_infinity()); |
174 | 165 | } |
175 | 166 |
|
176 | | - P += result.log_Q; |
| 167 | + P += result->first; |
177 | 168 |
|
178 | 169 | if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) { |
179 | | - const T_partials_return log_y = log(y_dbl); |
180 | | - const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y); |
| 170 | + const T_partials_return log_y = log(y_val); |
| 171 | + const T_partials_return alpha_minus_one = fma(alpha_val, log_y, -log_y); |
181 | 172 |
|
182 | | - const T_partials_return log_pdf = alpha_dbl * log(beta_dbl) |
183 | | - - lgamma(alpha_dbl) + alpha_minus_one |
| 173 | + const T_partials_return log_pdf = alpha_val * log(beta_val) |
| 174 | + - lgamma(alpha_val) + alpha_minus_one |
184 | 175 | - beta_y; |
185 | 176 |
|
186 | | - const T_partials_return hazard = exp(log_pdf - result.log_Q); // f/Q |
| 177 | + const T_partials_return hazard = exp(log_pdf - result->first); // f/Q |
187 | 178 |
|
188 | 179 | if constexpr (is_autodiff_v<T_y>) { |
189 | 180 | partials<0>(ops_partials)[n] -= hazard; |
190 | 181 | } |
191 | 182 | if constexpr (is_autodiff_v<T_inv_scale>) { |
192 | | - partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard; |
| 183 | + partials<2>(ops_partials)[n] -= (y_val / beta_val) * hazard; |
193 | 184 | } |
194 | 185 | } |
195 | 186 | if constexpr (is_autodiff_v<T_shape>) { |
196 | | - partials<1>(ops_partials)[n] += result.dlogQ_dalpha; |
| 187 | + partials<1>(ops_partials)[n] += result->second; |
197 | 188 | } |
198 | 189 | } |
199 | 190 | return ops_partials.build(P); |
|
0 commit comments