Skip to content

Commit 15e00a1

Browse files
authored
Merge pull request #3292 from stan-dev/fix/gamma-lccdf-v3-review
Fix/gamma lccdf v3 review
2 parents 86ac561 + 60acdd4 commit 15e00a1

File tree

2 files changed

+87
-129
lines changed

2 files changed

+87
-129
lines changed

stan/math/prim/fun/log_gamma_q_dgamma.hpp

Lines changed: 28 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/prim/fun/constants.hpp>
66
#include <stan/math/prim/fun/digamma.hpp>
77
#include <stan/math/prim/fun/exp.hpp>
8+
#include <stan/math/prim/fun/fabs.hpp>
89
#include <stan/math/prim/fun/gamma_p.hpp>
910
#include <stan/math/prim/fun/gamma_q.hpp>
1011
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
@@ -19,17 +20,6 @@
1920
namespace stan {
2021
namespace math {
2122

22-
/**
23-
* Result structure containing log(Q(a,z)) and its gradient with respect to a.
24-
*
25-
* @tparam T return type
26-
*/
27-
template <typename T>
28-
struct log_gamma_q_result {
29-
T log_q; ///< log(Q(a,z)) where Q is upper regularized incomplete gamma
30-
T dlog_q_da; ///< d/da log(Q(a,z))
31-
};
32-
3323
namespace internal {
3424

3525
/**
@@ -40,50 +30,36 @@ namespace internal {
4030
* @tparam T_z Type of value parameter z (double or fvar types)
4131
* @param a Shape parameter
4232
* @param z Value at which to evaluate
43-
* @param precision Convergence threshold
33+
* @param precision Convergence threshold, default of sqrt(machine_epsilon)
4434
* @param max_steps Maximum number of continued fraction iterations
4535
* @return log(Q(a,z)) with the return type of T_a and T_z
4636
*/
4737
template <typename T_a, typename T_z>
4838
inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
49-
double precision = 1e-16,
39+
double precision = 1.49012e-08,
5040
int max_steps = 250) {
51-
using stan::math::lgamma;
52-
using stan::math::log;
53-
using stan::math::value_of;
54-
using std::fabs;
5541
using T_return = return_type_t<T_a, T_z>;
56-
5742
const T_return log_prefactor = a * log(z) - z - lgamma(a);
5843

59-
T_return b = z + 1.0 - a;
60-
T_return C = (fabs(value_of(b)) >= EPSILON) ? b : T_return(EPSILON);
44+
T_return b_init = z + 1.0 - a;
45+
T_return C = (fabs(value_of(b_init)) >= EPSILON) ? b_init : std::decay_t<decltype(b_init)>(EPSILON);
6146
T_return D = 0.0;
6247
T_return f = C;
63-
6448
for (int i = 1; i <= max_steps; ++i) {
6549
T_a an = -i * (i - a);
66-
b += 2.0;
67-
50+
const T_return b = b_init + 2.0 * i;
6851
D = b + an * D;
69-
if (fabs(D) < EPSILON) {
70-
D = EPSILON;
71-
}
52+
D = (fabs(value_of(D)) >= EPSILON) ? D : std::decay_t<decltype(D)>(EPSILON);
7253
C = b + an / C;
73-
if (fabs(C) < EPSILON) {
74-
C = EPSILON;
75-
}
76-
54+
C = (fabs(value_of(C)) >= EPSILON) ? C : std::decay_t<decltype(C)>(EPSILON);
7755
D = inv(D);
78-
T_return delta = C * D;
56+
const T_return delta = C * D;
7957
f *= delta;
80-
81-
const double delta_m1 = value_of(fabs(value_of(delta) - 1.0));
58+
const double delta_m1 = fabs(value_of(delta) - 1.0);
8259
if (delta_m1 < precision) {
8360
break;
8461
}
8562
}
86-
8763
return log_prefactor - log(f);
8864
}
8965

@@ -102,52 +78,43 @@ inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
10278
* @tparam T_z type of the value parameter
10379
* @param a shape parameter (must be positive)
10480
* @param z value parameter (must be non-negative)
105-
* @param precision convergence threshold
81+
* @param precision convergence threshold, default of sqrt(machine_epsilon)
10682
* @param max_steps maximum iterations for continued fraction
10783
* @return structure containing log(Q(a,z)) and d/da log(Q(a,z))
10884
*/
10985
template <typename T_a, typename T_z>
110-
inline log_gamma_q_result<return_type_t<T_a, T_z>> log_gamma_q_dgamma(
111-
const T_a& a, const T_z& z, double precision = 1e-16, int max_steps = 250) {
112-
using std::exp;
113-
using std::log;
86+
inline std::pair<return_type_t<T_a, T_z>, return_type_t<T_a, T_z>> log_gamma_q_dgamma(
87+
const T_a& a, const T_z& z, double precision = 1.49012e-08, int max_steps = 250) {
11488
using T_return = return_type_t<T_a, T_z>;
115-
116-
const double a_dbl = value_of(a);
117-
const double z_dbl = value_of(z);
118-
119-
log_gamma_q_result<T_return> result;
120-
89+
const double a_val = value_of(a);
90+
const double z_val = value_of(z);
12191
// For z > a + 1, use continued fraction for better numerical stability
122-
if (z_dbl > a_dbl + 1.0) {
123-
result.log_q = internal::log_q_gamma_cf(a_dbl, z_dbl, precision, max_steps);
124-
92+
if (z_val > a_val + 1.0) {
93+
std::pair<T_return, T_return> result{internal::log_q_gamma_cf(a_val, z_val, precision, max_steps), 0.0};
12594
// For gradient, use: d/da log(Q) = (1/Q) * dQ/da
12695
// grad_reg_inc_gamma computes dQ/da
127-
const double Q_val = exp(result.log_q);
96+
const T_return Q_val = exp(result.first);
12897
const double dQ_da
129-
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
130-
result.dlog_q_da = dQ_da / Q_val;
131-
98+
= grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val));
99+
result.second = dQ_da / Q_val;
100+
return result;
132101
} else {
133102
// For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy
134-
const double P_val = gamma_p(a_dbl, z_dbl);
135-
result.log_q = log1m(P_val);
136-
103+
const double P_val = gamma_p(a_val, z_val);
104+
std::pair<T_return, T_return> result{log1m(P_val), 0.0};
137105
// Gradient: d/da log(Q) = (1/Q) * dQ/da
138106
// grad_reg_inc_gamma computes dQ/da
139-
const double Q_val = exp(result.log_q);
107+
const T_return Q_val = exp(result.first);
140108
if (Q_val > 0) {
141109
const double dQ_da
142-
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
143-
result.dlog_q_da = dQ_da / Q_val;
110+
= grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val));
111+
result.second = dQ_da / Q_val;
144112
} else {
145113
// Fallback if Q rounds to zero - use asymptotic approximation
146-
result.dlog_q_da = log(z_dbl) - digamma(a_dbl);
114+
result.second = log(z_val) - digamma(a_val);
147115
}
116+
return result;
148117
}
149-
150-
return result;
151118
}
152119

153120
} // namespace math

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 59 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -22,79 +22,77 @@
2222
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
2323
#include <stan/math/prim/functor/partials_propagator.hpp>
2424
#include <cmath>
25+
#include <optional>
2526

2627
namespace stan {
2728
namespace math {
2829
namespace internal {
29-
template <typename T>
30-
struct Q_eval {
31-
T log_Q{0.0};
32-
T dlogQ_dalpha{0.0};
33-
bool ok{false};
34-
};
3530

3631
/**
3732
* Computes log q and d(log q) / d(alpha) using continued fraction.
3833
*/
39-
template <typename T, typename T_shape, bool any_fvar, bool partials_fvar>
40-
static inline Q_eval<T> eval_q_cf(const T& alpha, const T& beta_y) {
41-
Q_eval<T> out;
34+
template <bool any_fvar, bool partials_fvar, typename T_shape, typename T1, typename T2>
35+
inline std::optional<std::pair<return_type_t<T1, T2>, return_type_t<T1, T2>>>
36+
eval_q_cf(const T1& alpha, const T2& beta_y) {
37+
using scalar_t = return_type_t<T1, T2>;
38+
using ret_t = std::pair<scalar_t, scalar_t>;
4239
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
43-
auto log_q_result
40+
std::pair<double, double> log_q_result
4441
= log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y));
45-
out.log_Q = log_q_result.log_q;
46-
out.dlogQ_dalpha = log_q_result.dlog_q_da;
42+
if (likely(std::isfinite(value_of_rec(log_q_result.first)))) {
43+
return std::optional{log_q_result};
44+
} else {
45+
return std::optional<ret_t>{std::nullopt};
46+
}
4747
} else {
48-
out.log_Q = internal::log_q_gamma_cf(alpha, beta_y);
48+
ret_t out{internal::log_q_gamma_cf(alpha, beta_y), 0.0};
49+
if (unlikely(!std::isfinite(value_of_rec(out.first)))) {
50+
return std::optional<ret_t>{std::nullopt};
51+
}
4952
if constexpr (is_autodiff_v<T_shape>) {
5053
if constexpr (!partials_fvar) {
51-
out.dlogQ_dalpha
54+
out.second
5255
= grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha), digamma(alpha))
53-
/ exp(out.log_Q);
56+
/ exp(out.first);
5457
} else {
55-
T alpha_unit = alpha;
58+
auto alpha_unit = alpha;
5659
alpha_unit.d_ = 1;
57-
T beta_y_unit = beta_y;
60+
auto beta_y_unit = beta_y;
5861
beta_y_unit.d_ = 0;
59-
T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
60-
out.dlogQ_dalpha = log_Q_fvar.d_;
62+
auto log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
63+
out.second = log_Q_fvar.d_;
6164
}
6265
}
66+
return std::optional{out};
6367
}
64-
65-
out.ok = std::isfinite(value_of_rec(out.log_Q));
66-
return out;
6768
}
6869

6970
/**
7071
* Computes log q and d(log q) / d(alpha) using log1m.
7172
*/
72-
template <typename T, typename T_shape, bool partials_fvar>
73-
static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) {
74-
Q_eval<T> out;
75-
out.log_Q = log1m(gamma_p(alpha, beta_y));
76-
77-
if (!std::isfinite(value_of_rec(out.log_Q))) {
78-
out.ok = false;
79-
return out;
73+
template <bool partials_fvar, typename T_shape, typename T1, typename T2>
74+
inline std::optional<std::pair<return_type_t<T1, T2>, return_type_t<T1, T2>>>
75+
eval_q_log1m(const T1& alpha, const T2& beta_y) {
76+
using scalar_t = return_type_t<T1, T2>;
77+
using ret_t = std::pair<scalar_t, scalar_t>;
78+
ret_t out{log1m(gamma_p(alpha, beta_y)), 0.0};
79+
if (unlikely(!std::isfinite(value_of_rec(out.first)))) {
80+
return std::optional<ret_t>{std::nullopt};
8081
}
81-
8282
if constexpr (is_autodiff_v<T_shape>) {
8383
if constexpr (partials_fvar) {
84-
T alpha_unit = alpha;
84+
auto alpha_unit = alpha;
8585
alpha_unit.d_ = 1;
86-
T beta_unit = beta_y;
86+
auto beta_unit = beta_y;
8787
beta_unit.d_ = 0;
88-
T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
89-
out.dlogQ_dalpha = log_Q_fvar.d_;
88+
auto log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
89+
out.second = log_Q_fvar.d_;
9090
} else {
91-
out.dlogQ_dalpha
92-
= -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q);
91+
out.second
92+
= -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.first);
9393
}
9494
}
95-
96-
out.ok = true;
97-
return out;
95+
return std::optional{out};
9896
}
9997
} // namespace internal
10098

@@ -137,63 +135,56 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
137135
for (size_t n = 0; n < N; n++) {
138136
// Explicit results for extreme values
139137
// The gradients are technically ill-defined, but treated as zero
140-
const T_partials_return y_dbl = y_vec.val(n);
141-
if (y_dbl == 0.0) {
138+
const T_partials_return y_val = y_vec.val(n);
139+
if (y_val == 0.0) {
142140
continue;
143141
}
144-
if (y_dbl == INFTY) {
142+
if (y_val == INFTY) {
145143
return ops_partials.build(negative_infinity());
146144
}
147145

148-
const T_partials_return alpha_dbl = alpha_vec.val(n);
149-
const T_partials_return beta_dbl = beta_vec.val(n);
146+
const T_partials_return alpha_val = alpha_vec.val(n);
147+
const T_partials_return beta_val = beta_vec.val(n);
150148

151-
const T_partials_return beta_y = beta_dbl * y_dbl;
149+
const T_partials_return beta_y = beta_val * y_val;
152150
if (beta_y == INFTY) {
153151
return ops_partials.build(negative_infinity());
154152
}
155-
156-
const bool use_continued_fraction = beta_y > alpha_dbl + 1.0;
157-
internal::Q_eval<T_partials_return> result;
158-
if (use_continued_fraction) {
159-
result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
160-
partials_fvar>(alpha_dbl, beta_y);
153+
std::optional<std::pair<T_partials_return, T_partials_return>> result;
154+
if (beta_y > alpha_val + 1.0) {
155+
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y);
161156
} else {
162-
result
163-
= internal::eval_q_log1m<T_partials_return, T_shape, partials_fvar>(
164-
alpha_dbl, beta_y);
165-
166-
if (!result.ok && beta_y > 0.0) {
157+
result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_val, beta_y);
158+
if (!result && beta_y > 0.0) {
167159
// Fallback to continued fraction if log1m fails
168-
result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
169-
partials_fvar>(alpha_dbl, beta_y);
160+
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y);
170161
}
171162
}
172-
if (!result.ok) {
163+
if (unlikely(!result)) {
173164
return ops_partials.build(negative_infinity());
174165
}
175166

176-
P += result.log_Q;
167+
P += result->first;
177168

178169
if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
179-
const T_partials_return log_y = log(y_dbl);
180-
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);
170+
const T_partials_return log_y = log(y_val);
171+
const T_partials_return alpha_minus_one = fma(alpha_val, log_y, -log_y);
181172

182-
const T_partials_return log_pdf = alpha_dbl * log(beta_dbl)
183-
- lgamma(alpha_dbl) + alpha_minus_one
173+
const T_partials_return log_pdf = alpha_val * log(beta_val)
174+
- lgamma(alpha_val) + alpha_minus_one
184175
- beta_y;
185176

186-
const T_partials_return hazard = exp(log_pdf - result.log_Q); // f/Q
177+
const T_partials_return hazard = exp(log_pdf - result->first); // f/Q
187178

188179
if constexpr (is_autodiff_v<T_y>) {
189180
partials<0>(ops_partials)[n] -= hazard;
190181
}
191182
if constexpr (is_autodiff_v<T_inv_scale>) {
192-
partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
183+
partials<2>(ops_partials)[n] -= (y_val / beta_val) * hazard;
193184
}
194185
}
195186
if constexpr (is_autodiff_v<T_shape>) {
196-
partials<1>(ops_partials)[n] += result.dlogQ_dalpha;
187+
partials<1>(ops_partials)[n] += result->second;
197188
}
198189
}
199190
return ops_partials.build(P);

0 commit comments

Comments
 (0)