Skip to content

Commit 42063be

Browse files
authored
Merge pull request #3272 from stan-dev/revert-3266-fix-gamma-lccdf-v2
2 parents bf9fc6b + d7f7434 commit 42063be

File tree

4 files changed

+39
-430
lines changed

4 files changed

+39
-430
lines changed

stan/math/prim/fun/log_gamma_q_dgamma.hpp

Lines changed: 0 additions & 156 deletions
This file was deleted.

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 39 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,15 @@
66
#include <stan/math/prim/fun/constants.hpp>
77
#include <stan/math/prim/fun/digamma.hpp>
88
#include <stan/math/prim/fun/exp.hpp>
9-
#include <stan/math/prim/fun/fma.hpp>
10-
#include <stan/math/prim/fun/gamma_p.hpp>
9+
#include <stan/math/prim/fun/gamma_q.hpp>
1110
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
12-
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
13-
#include <stan/math/prim/fun/lgamma.hpp>
1411
#include <stan/math/prim/fun/log.hpp>
15-
#include <stan/math/prim/fun/log1m.hpp>
1612
#include <stan/math/prim/fun/max_size.hpp>
1713
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1814
#include <stan/math/prim/fun/size.hpp>
1915
#include <stan/math/prim/fun/size_zero.hpp>
2016
#include <stan/math/prim/fun/tgamma.hpp>
21-
#include <stan/math/prim/fun/value_of_rec.hpp>
22-
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
17+
#include <stan/math/prim/fun/value_of.hpp>
2318
#include <stan/math/prim/functor/partials_propagator.hpp>
2419
#include <cmath>
2520

@@ -29,9 +24,10 @@ namespace math {
2924
template <typename T_y, typename T_shape, typename T_inv_scale>
3025
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
3126
const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
27+
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
3228
using std::exp;
3329
using std::log;
34-
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
30+
using std::pow;
3531
using T_y_ref = ref_type_t<T_y>;
3632
using T_alpha_ref = ref_type_t<T_shape>;
3733
using T_beta_ref = ref_type_t<T_inv_scale>;
@@ -55,127 +51,61 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
5551
scalar_seq_view<T_y_ref> y_vec(y_ref);
5652
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
5753
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
58-
const size_t N = max_size(y, alpha, beta);
59-
60-
constexpr bool any_fvar = is_fvar<scalar_type_t<T_y>>::value
61-
|| is_fvar<scalar_type_t<T_shape>>::value
62-
|| is_fvar<scalar_type_t<T_inv_scale>>::value;
63-
constexpr bool partials_fvar = is_fvar<T_partials_return>::value;
54+
size_t N = max_size(y, alpha, beta);
55+
56+
// Explicit return for extreme values
57+
// The gradients are technically ill-defined, but treated as zero
58+
for (size_t i = 0; i < stan::math::size(y); i++) {
59+
if (y_vec.val(i) == 0) {
60+
// LCCDF(0) = log(P(Y > 0)) = log(1) = 0
61+
return ops_partials.build(0.0);
62+
}
63+
}
6464

6565
for (size_t n = 0; n < N; n++) {
6666
// Explicit results for extreme values
6767
// The gradients are technically ill-defined, but treated as zero
68-
const T_partials_return y_dbl = y_vec.val(n);
69-
if (y_dbl == 0.0) {
70-
continue;
71-
}
72-
if (y_dbl == INFTY) {
68+
if (y_vec.val(n) == INFTY) {
69+
// LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
7370
return ops_partials.build(negative_infinity());
7471
}
7572

73+
const T_partials_return y_dbl = y_vec.val(n);
7674
const T_partials_return alpha_dbl = alpha_vec.val(n);
7775
const T_partials_return beta_dbl = beta_vec.val(n);
76+
const T_partials_return beta_y_dbl = beta_dbl * y_dbl;
7877

79-
const T_partials_return beta_y = beta_dbl * y_dbl;
80-
if (beta_y == INFTY) {
81-
return ops_partials.build(negative_infinity());
82-
}
78+
// Qn = 1 - Pn
79+
const T_partials_return Qn = gamma_q(alpha_dbl, beta_y_dbl);
80+
const T_partials_return log_Qn = log(Qn);
8381

84-
bool use_cf = beta_y > alpha_dbl + 1.0;
85-
T_partials_return log_Qn;
86-
[[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;
87-
88-
// Branch by autodiff type first, then handle use_cf logic inside each path
89-
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
90-
// var-only path: use log_gamma_q_dgamma which computes both log_q
91-
// and its gradient analytically with double inputs
92-
const double beta_y_dbl = value_of_rec(beta_y);
93-
const double alpha_dbl_val = value_of_rec(alpha_dbl);
94-
95-
if (use_cf) {
96-
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
97-
log_Qn = log_q_result.log_q;
98-
dlogQ_dalpha = log_q_result.dlog_q_da;
99-
} else {
100-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
101-
log_Qn = log1m(Pn);
102-
const T_partials_return Qn = exp(log_Qn);
103-
104-
// Check if we need to fallback to continued fraction
105-
bool need_cf_fallback
106-
= !std::isfinite(value_of_rec(log_Qn)) || Qn <= 0.0;
107-
if (need_cf_fallback && beta_y > 0.0) {
108-
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
109-
log_Qn = log_q_result.log_q;
110-
dlogQ_dalpha = log_q_result.dlog_q_da;
111-
} else {
112-
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
113-
}
114-
}
115-
} else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
116-
// fvar path: use unit derivative trick to compute gradients
117-
T_partials_return alpha_unit = alpha_dbl;
118-
alpha_unit.d_ = 1;
119-
T_partials_return beta_unit = beta_y;
120-
beta_unit.d_ = 0;
121-
122-
if (use_cf) {
123-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
124-
T_partials_return log_Qn_fvar
125-
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
126-
dlogQ_dalpha = log_Qn_fvar.d_;
127-
} else {
128-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
129-
log_Qn = log1m(Pn);
130-
131-
if (!std::isfinite(value_of_rec(log_Qn)) && beta_y > 0.0) {
132-
// Fallback to continued fraction
133-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
134-
T_partials_return log_Qn_fvar
135-
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
136-
dlogQ_dalpha = log_Qn_fvar.d_;
137-
} else {
138-
T_partials_return log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
139-
dlogQ_dalpha = log_Qn_fvar.d_;
140-
}
141-
}
142-
} else {
143-
// No alpha derivative needed (alpha is constant or double-only)
144-
if (use_cf) {
145-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
146-
} else {
147-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
148-
log_Qn = log1m(Pn);
149-
150-
if (!std::isfinite(value_of_rec(log_Qn)) && beta_y > 0.0) {
151-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
152-
}
153-
}
154-
}
155-
if (!std::isfinite(value_of_rec(log_Qn))) {
156-
return ops_partials.build(negative_infinity());
157-
}
15882
P += log_Qn;
15983

160-
if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
161-
const T_partials_return log_y = log(y_dbl);
162-
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);
163-
164-
const T_partials_return log_pdf = alpha_dbl * log(beta_dbl)
165-
- lgamma(alpha_dbl) + alpha_minus_one
166-
- beta_y;
167-
168-
const T_partials_return hazard = exp(log_pdf - log_Qn); // f/Q
84+
if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
85+
const T_partials_return log_y_dbl = log(y_dbl);
86+
const T_partials_return log_beta_dbl = log(beta_dbl);
87+
const T_partials_return log_pdf
88+
= alpha_dbl * log_beta_dbl - lgamma(alpha_dbl)
89+
+ (alpha_dbl - 1.0) * log_y_dbl - beta_y_dbl;
90+
const T_partials_return common_term = exp(log_pdf - log_Qn);
16991

17092
if constexpr (is_autodiff_v<T_y>) {
171-
partials<0>(ops_partials)[n] -= hazard;
93+
// d/dy log(1-F(y)) = -f(y)/(1-F(y))
94+
partials<0>(ops_partials)[n] -= common_term;
17295
}
17396
if constexpr (is_autodiff_v<T_inv_scale>) {
174-
partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
97+
// d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
98+
partials<2>(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
17599
}
176100
}
101+
177102
if constexpr (is_autodiff_v<T_shape>) {
178-
partials<1>(ops_partials)[n] += dlogQ_dalpha;
103+
const T_partials_return digamma_val = digamma(alpha_dbl);
104+
const T_partials_return gamma_val = tgamma(alpha_dbl);
105+
// d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
106+
partials<1>(ops_partials)[n]
107+
+= grad_reg_inc_gamma(alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
108+
/ Qn;
179109
}
180110
}
181111
return ops_partials.build(P);

0 commit comments

Comments
 (0)