Skip to content

Commit bf9fc6b

Browse files
authored
Merge pull request #3266 from stan-dev/fix-gamma-lccdf-v2
super stable gamma_lccdf
2 parents 5b7cef4 + eb92b73 commit bf9fc6b

File tree

4 files changed

+430
-39
lines changed

4 files changed

+430
-39
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#ifndef STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP
2+
#define STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/fun/constants.hpp>
6+
#include <stan/math/prim/fun/digamma.hpp>
7+
#include <stan/math/prim/fun/exp.hpp>
8+
#include <stan/math/prim/fun/gamma_p.hpp>
9+
#include <stan/math/prim/fun/gamma_q.hpp>
10+
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
11+
#include <stan/math/prim/fun/inv.hpp>
12+
#include <stan/math/prim/fun/lgamma.hpp>
13+
#include <stan/math/prim/fun/log.hpp>
14+
#include <stan/math/prim/fun/log1m.hpp>
15+
#include <stan/math/prim/fun/tgamma.hpp>
16+
#include <stan/math/prim/fun/value_of.hpp>
17+
#include <cmath>
18+
19+
namespace stan {
20+
namespace math {
21+
22+
/**
 * Pair of values produced when evaluating the log of the regularized upper
 * incomplete gamma function Q(a,z): the log value itself and its derivative
 * with respect to the shape argument a.
 *
 * @tparam T scalar type used for both members
 */
template <class T>
struct log_gamma_q_result {
  /** log(Q(a,z)), Q being the upper regularized incomplete gamma function */
  T log_q;
  /** Derivative of log(Q(a,z)) with respect to a */
  T dlog_q_da;
};
32+
33+
namespace internal {
34+
35+
/**
36+
* Compute log(Q(a,z)) using continued fraction expansion for upper incomplete
37+
* gamma function.
38+
*
39+
* @tparam T_a Type of shape parameter a (double or fvar types)
40+
* @tparam T_z Type of value parameter z (double or fvar types)
41+
* @param a Shape parameter
42+
* @param z Value at which to evaluate
43+
* @param precision Convergence threshold
44+
* @param max_steps Maximum number of continued fraction iterations
45+
* @return log(Q(a,z)) with the return type of T_a and T_z
46+
*/
47+
template <typename T_a, typename T_z>
48+
inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
49+
double precision = 1e-16,
50+
int max_steps = 250) {
51+
using stan::math::lgamma;
52+
using stan::math::log;
53+
using stan::math::value_of;
54+
using std::fabs;
55+
using T_return = return_type_t<T_a, T_z>;
56+
57+
const T_return log_prefactor = a * log(z) - z - lgamma(a);
58+
59+
T_return b = z + 1.0 - a;
60+
T_return C = (fabs(value_of(b)) >= EPSILON) ? b : T_return(EPSILON);
61+
T_return D = 0.0;
62+
T_return f = C;
63+
64+
for (int i = 1; i <= max_steps; ++i) {
65+
T_a an = -i * (i - a);
66+
b += 2.0;
67+
68+
D = b + an * D;
69+
if (fabs(D) < EPSILON) {
70+
D = EPSILON;
71+
}
72+
C = b + an / C;
73+
if (fabs(C) < EPSILON) {
74+
C = EPSILON;
75+
}
76+
77+
D = inv(D);
78+
T_return delta = C * D;
79+
f *= delta;
80+
81+
const double delta_m1 = value_of(fabs(value_of(delta) - 1.0));
82+
if (delta_m1 < precision) {
83+
break;
84+
}
85+
}
86+
87+
return log_prefactor - log(f);
88+
}
89+
90+
} // namespace internal
91+
92+
/**
93+
* Compute log(Q(a,z)) and its gradient with respect to a using continued
94+
* fraction expansion, where Q(a,z) = Gamma(a,z) / Gamma(a) is the regularized
95+
* upper incomplete gamma function.
96+
*
97+
* This uses a continued fraction representation for numerical stability when
98+
* computing the upper incomplete gamma function in log space, along with
99+
* analytical gradient computation.
100+
*
101+
* @tparam T_a type of the shape parameter
102+
* @tparam T_z type of the value parameter
103+
* @param a shape parameter (must be positive)
104+
* @param z value parameter (must be non-negative)
105+
* @param precision convergence threshold
106+
* @param max_steps maximum iterations for continued fraction
107+
* @return structure containing log(Q(a,z)) and d/da log(Q(a,z))
108+
*/
109+
template <typename T_a, typename T_z>
110+
inline log_gamma_q_result<return_type_t<T_a, T_z>> log_gamma_q_dgamma(
111+
const T_a& a, const T_z& z, double precision = 1e-16, int max_steps = 250) {
112+
using std::exp;
113+
using std::log;
114+
using T_return = return_type_t<T_a, T_z>;
115+
116+
const double a_dbl = value_of(a);
117+
const double z_dbl = value_of(z);
118+
119+
log_gamma_q_result<T_return> result;
120+
121+
// For z > a + 1, use continued fraction for better numerical stability
122+
if (z_dbl > a_dbl + 1.0) {
123+
result.log_q = internal::log_q_gamma_cf(a_dbl, z_dbl, precision, max_steps);
124+
125+
// For gradient, use: d/da log(Q) = (1/Q) * dQ/da
126+
// grad_reg_inc_gamma computes dQ/da
127+
const double Q_val = exp(result.log_q);
128+
const double dQ_da
129+
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
130+
result.dlog_q_da = dQ_da / Q_val;
131+
132+
} else {
133+
// For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy
134+
const double P_val = gamma_p(a_dbl, z_dbl);
135+
result.log_q = log1m(P_val);
136+
137+
// Gradient: d/da log(Q) = (1/Q) * dQ/da
138+
// grad_reg_inc_gamma computes dQ/da
139+
const double Q_val = exp(result.log_q);
140+
if (Q_val > 0) {
141+
const double dQ_da
142+
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
143+
result.dlog_q_da = dQ_da / Q_val;
144+
} else {
145+
// Fallback if Q rounds to zero - use asymptotic approximation
146+
result.dlog_q_da = log(z_dbl) - digamma(a_dbl);
147+
}
148+
}
149+
150+
return result;
151+
}
152+
153+
} // namespace math
154+
} // namespace stan
155+
156+
#endif

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 109 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,20 @@
66
#include <stan/math/prim/fun/constants.hpp>
77
#include <stan/math/prim/fun/digamma.hpp>
88
#include <stan/math/prim/fun/exp.hpp>
9-
#include <stan/math/prim/fun/gamma_q.hpp>
9+
#include <stan/math/prim/fun/fma.hpp>
10+
#include <stan/math/prim/fun/gamma_p.hpp>
1011
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
12+
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
13+
#include <stan/math/prim/fun/lgamma.hpp>
1114
#include <stan/math/prim/fun/log.hpp>
15+
#include <stan/math/prim/fun/log1m.hpp>
1216
#include <stan/math/prim/fun/max_size.hpp>
1317
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1418
#include <stan/math/prim/fun/size.hpp>
1519
#include <stan/math/prim/fun/size_zero.hpp>
1620
#include <stan/math/prim/fun/tgamma.hpp>
17-
#include <stan/math/prim/fun/value_of.hpp>
21+
#include <stan/math/prim/fun/value_of_rec.hpp>
22+
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
1823
#include <stan/math/prim/functor/partials_propagator.hpp>
1924
#include <cmath>
2025

@@ -24,10 +29,9 @@ namespace math {
2429
template <typename T_y, typename T_shape, typename T_inv_scale>
2530
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
2631
const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
27-
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
2832
using std::exp;
2933
using std::log;
30-
using std::pow;
34+
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
3135
using T_y_ref = ref_type_t<T_y>;
3236
using T_alpha_ref = ref_type_t<T_shape>;
3337
using T_beta_ref = ref_type_t<T_inv_scale>;
@@ -51,61 +55,127 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
5155
scalar_seq_view<T_y_ref> y_vec(y_ref);
5256
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
5357
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
54-
size_t N = max_size(y, alpha, beta);
55-
56-
// Explicit return for extreme values
57-
// The gradients are technically ill-defined, but treated as zero
58-
for (size_t i = 0; i < stan::math::size(y); i++) {
59-
if (y_vec.val(i) == 0) {
60-
// LCCDF(0) = log(P(Y > 0)) = log(1) = 0
61-
return ops_partials.build(0.0);
62-
}
63-
}
58+
const size_t N = max_size(y, alpha, beta);
59+
60+
constexpr bool any_fvar = is_fvar<scalar_type_t<T_y>>::value
61+
|| is_fvar<scalar_type_t<T_shape>>::value
62+
|| is_fvar<scalar_type_t<T_inv_scale>>::value;
63+
constexpr bool partials_fvar = is_fvar<T_partials_return>::value;
6464

6565
for (size_t n = 0; n < N; n++) {
6666
// Explicit results for extreme values
6767
// The gradients are technically ill-defined, but treated as zero
68-
if (y_vec.val(n) == INFTY) {
69-
// LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
68+
const T_partials_return y_dbl = y_vec.val(n);
69+
if (y_dbl == 0.0) {
70+
continue;
71+
}
72+
if (y_dbl == INFTY) {
7073
return ops_partials.build(negative_infinity());
7174
}
7275

73-
const T_partials_return y_dbl = y_vec.val(n);
7476
const T_partials_return alpha_dbl = alpha_vec.val(n);
7577
const T_partials_return beta_dbl = beta_vec.val(n);
76-
const T_partials_return beta_y_dbl = beta_dbl * y_dbl;
7778

78-
// Qn = 1 - Pn
79-
const T_partials_return Qn = gamma_q(alpha_dbl, beta_y_dbl);
80-
const T_partials_return log_Qn = log(Qn);
79+
const T_partials_return beta_y = beta_dbl * y_dbl;
80+
if (beta_y == INFTY) {
81+
return ops_partials.build(negative_infinity());
82+
}
8183

84+
bool use_cf = beta_y > alpha_dbl + 1.0;
85+
T_partials_return log_Qn;
86+
[[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;
87+
88+
// Branch by autodiff type first, then handle use_cf logic inside each path
89+
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
90+
// var-only path: use log_gamma_q_dgamma which computes both log_q
91+
// and its gradient analytically with double inputs
92+
const double beta_y_dbl = value_of_rec(beta_y);
93+
const double alpha_dbl_val = value_of_rec(alpha_dbl);
94+
95+
if (use_cf) {
96+
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
97+
log_Qn = log_q_result.log_q;
98+
dlogQ_dalpha = log_q_result.dlog_q_da;
99+
} else {
100+
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
101+
log_Qn = log1m(Pn);
102+
const T_partials_return Qn = exp(log_Qn);
103+
104+
// Check if we need to fallback to continued fraction
105+
bool need_cf_fallback
106+
= !std::isfinite(value_of_rec(log_Qn)) || Qn <= 0.0;
107+
if (need_cf_fallback && beta_y > 0.0) {
108+
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
109+
log_Qn = log_q_result.log_q;
110+
dlogQ_dalpha = log_q_result.dlog_q_da;
111+
} else {
112+
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
113+
}
114+
}
115+
} else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
116+
// fvar path: use unit derivative trick to compute gradients
117+
T_partials_return alpha_unit = alpha_dbl;
118+
alpha_unit.d_ = 1;
119+
T_partials_return beta_unit = beta_y;
120+
beta_unit.d_ = 0;
121+
122+
if (use_cf) {
123+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
124+
T_partials_return log_Qn_fvar
125+
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
126+
dlogQ_dalpha = log_Qn_fvar.d_;
127+
} else {
128+
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
129+
log_Qn = log1m(Pn);
130+
131+
if (!std::isfinite(value_of_rec(log_Qn)) && beta_y > 0.0) {
132+
// Fallback to continued fraction
133+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
134+
T_partials_return log_Qn_fvar
135+
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
136+
dlogQ_dalpha = log_Qn_fvar.d_;
137+
} else {
138+
T_partials_return log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
139+
dlogQ_dalpha = log_Qn_fvar.d_;
140+
}
141+
}
142+
} else {
143+
// No alpha derivative needed (alpha is constant or double-only)
144+
if (use_cf) {
145+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
146+
} else {
147+
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
148+
log_Qn = log1m(Pn);
149+
150+
if (!std::isfinite(value_of_rec(log_Qn)) && beta_y > 0.0) {
151+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
152+
}
153+
}
154+
}
155+
if (!std::isfinite(value_of_rec(log_Qn))) {
156+
return ops_partials.build(negative_infinity());
157+
}
82158
P += log_Qn;
83159

84-
if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
85-
const T_partials_return log_y_dbl = log(y_dbl);
86-
const T_partials_return log_beta_dbl = log(beta_dbl);
87-
const T_partials_return log_pdf
88-
= alpha_dbl * log_beta_dbl - lgamma(alpha_dbl)
89-
+ (alpha_dbl - 1.0) * log_y_dbl - beta_y_dbl;
90-
const T_partials_return common_term = exp(log_pdf - log_Qn);
160+
if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
161+
const T_partials_return log_y = log(y_dbl);
162+
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);
163+
164+
const T_partials_return log_pdf = alpha_dbl * log(beta_dbl)
165+
- lgamma(alpha_dbl) + alpha_minus_one
166+
- beta_y;
167+
168+
const T_partials_return hazard = exp(log_pdf - log_Qn); // f/Q
91169

92170
if constexpr (is_autodiff_v<T_y>) {
93-
// d/dy log(1-F(y)) = -f(y)/(1-F(y))
94-
partials<0>(ops_partials)[n] -= common_term;
171+
partials<0>(ops_partials)[n] -= hazard;
95172
}
96173
if constexpr (is_autodiff_v<T_inv_scale>) {
97-
// d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
98-
partials<2>(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
174+
partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
99175
}
100176
}
101-
102177
if constexpr (is_autodiff_v<T_shape>) {
103-
const T_partials_return digamma_val = digamma(alpha_dbl);
104-
const T_partials_return gamma_val = tgamma(alpha_dbl);
105-
// d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
106-
partials<1>(ops_partials)[n]
107-
+= grad_reg_inc_gamma(alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
108-
/ Qn;
178+
partials<1>(ops_partials)[n] += dlogQ_dalpha;
109179
}
110180
}
111181
return ops_partials.build(P);

0 commit comments

Comments
 (0)