66#include < stan/math/prim/fun/constants.hpp>
77#include < stan/math/prim/fun/digamma.hpp>
88#include < stan/math/prim/fun/exp.hpp>
9- #include < stan/math/prim/fun/fma.hpp>
10- #include < stan/math/prim/fun/gamma_p.hpp>
9+ #include < stan/math/prim/fun/gamma_q.hpp>
1110#include < stan/math/prim/fun/grad_reg_inc_gamma.hpp>
12- #include < stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
13- #include < stan/math/prim/fun/lgamma.hpp>
1411#include < stan/math/prim/fun/log.hpp>
15- #include < stan/math/prim/fun/log1m.hpp>
1612#include < stan/math/prim/fun/max_size.hpp>
1713#include < stan/math/prim/fun/scalar_seq_view.hpp>
1814#include < stan/math/prim/fun/size.hpp>
1915#include < stan/math/prim/fun/size_zero.hpp>
2016#include < stan/math/prim/fun/tgamma.hpp>
21- #include < stan/math/prim/fun/value_of_rec.hpp>
22- #include < stan/math/prim/fun/log_gamma_q_dgamma.hpp>
17+ #include < stan/math/prim/fun/value_of.hpp>
2318#include < stan/math/prim/functor/partials_propagator.hpp>
2419#include < cmath>
2520
@@ -29,9 +24,10 @@ namespace math {
2924template <typename T_y, typename T_shape, typename T_inv_scale>
3025inline return_type_t <T_y, T_shape, T_inv_scale> gamma_lccdf (
3126 const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
27+ using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale>;
3228 using std::exp;
3329 using std::log;
34- using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale> ;
30+ using std::pow ;
3531 using T_y_ref = ref_type_t <T_y>;
3632 using T_alpha_ref = ref_type_t <T_shape>;
3733 using T_beta_ref = ref_type_t <T_inv_scale>;
@@ -55,127 +51,61 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
5551 scalar_seq_view<T_y_ref> y_vec (y_ref);
5652 scalar_seq_view<T_alpha_ref> alpha_vec (alpha_ref);
5753 scalar_seq_view<T_beta_ref> beta_vec (beta_ref);
58- const size_t N = max_size (y, alpha, beta);
59-
60- constexpr bool any_fvar = is_fvar<scalar_type_t <T_y>>::value
61- || is_fvar<scalar_type_t <T_shape>>::value
62- || is_fvar<scalar_type_t <T_inv_scale>>::value;
63- constexpr bool partials_fvar = is_fvar<T_partials_return>::value;
54+ size_t N = max_size (y, alpha, beta);
55+
56+ // Explicit return for extreme values
57+ // The gradients are technically ill-defined, but treated as zero
58+ for (size_t i = 0 ; i < stan::math::size (y); i++) {
59+ if (y_vec.val (i) == 0 ) {
60+ // LCCDF(0) = log(P(Y > 0)) = log(1) = 0
61+ return ops_partials.build (0.0 );
62+ }
63+ }
6464
6565 for (size_t n = 0 ; n < N; n++) {
6666 // Explicit results for extreme values
6767 // The gradients are technically ill-defined, but treated as zero
68- const T_partials_return y_dbl = y_vec.val (n);
69- if (y_dbl == 0.0 ) {
70- continue ;
71- }
72- if (y_dbl == INFTY) {
68+ if (y_vec.val (n) == INFTY) {
69+ // LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
7370 return ops_partials.build (negative_infinity ());
7471 }
7572
73+ const T_partials_return y_dbl = y_vec.val (n);
7674 const T_partials_return alpha_dbl = alpha_vec.val (n);
7775 const T_partials_return beta_dbl = beta_vec.val (n);
76+ const T_partials_return beta_y_dbl = beta_dbl * y_dbl;
7877
79- const T_partials_return beta_y = beta_dbl * y_dbl;
80- if (beta_y == INFTY) {
81- return ops_partials.build (negative_infinity ());
82- }
78+ // Qn = 1 - Pn
79+ const T_partials_return Qn = gamma_q (alpha_dbl, beta_y_dbl);
80+ const T_partials_return log_Qn = log (Qn);
8381
84- bool use_cf = beta_y > alpha_dbl + 1.0 ;
85- T_partials_return log_Qn;
86- [[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0 ;
87-
88- // Branch by autodiff type first, then handle use_cf logic inside each path
89- if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
90- // var-only path: use log_gamma_q_dgamma which computes both log_q
91- // and its gradient analytically with double inputs
92- const double beta_y_dbl = value_of_rec (beta_y);
93- const double alpha_dbl_val = value_of_rec (alpha_dbl);
94-
95- if (use_cf) {
96- auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
97- log_Qn = log_q_result.log_q ;
98- dlogQ_dalpha = log_q_result.dlog_q_da ;
99- } else {
100- const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
101- log_Qn = log1m (Pn);
102- const T_partials_return Qn = exp (log_Qn);
103-
104- // Check if we need to fallback to continued fraction
105- bool need_cf_fallback
106- = !std::isfinite (value_of_rec (log_Qn)) || Qn <= 0.0 ;
107- if (need_cf_fallback && beta_y > 0.0 ) {
108- auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
109- log_Qn = log_q_result.log_q ;
110- dlogQ_dalpha = log_q_result.dlog_q_da ;
111- } else {
112- dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha_dbl, beta_y) / Qn;
113- }
114- }
115- } else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
116- // fvar path: use unit derivative trick to compute gradients
117- T_partials_return alpha_unit = alpha_dbl;
118- alpha_unit.d_ = 1 ;
119- T_partials_return beta_unit = beta_y;
120- beta_unit.d_ = 0 ;
121-
122- if (use_cf) {
123- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
124- T_partials_return log_Qn_fvar
125- = internal::log_q_gamma_cf (alpha_unit, beta_unit);
126- dlogQ_dalpha = log_Qn_fvar.d_ ;
127- } else {
128- const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
129- log_Qn = log1m (Pn);
130-
131- if (!std::isfinite (value_of_rec (log_Qn)) && beta_y > 0.0 ) {
132- // Fallback to continued fraction
133- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
134- T_partials_return log_Qn_fvar
135- = internal::log_q_gamma_cf (alpha_unit, beta_unit);
136- dlogQ_dalpha = log_Qn_fvar.d_ ;
137- } else {
138- T_partials_return log_Qn_fvar = log1m (gamma_p (alpha_unit, beta_unit));
139- dlogQ_dalpha = log_Qn_fvar.d_ ;
140- }
141- }
142- } else {
143- // No alpha derivative needed (alpha is constant or double-only)
144- if (use_cf) {
145- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
146- } else {
147- const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
148- log_Qn = log1m (Pn);
149-
150- if (!std::isfinite (value_of_rec (log_Qn)) && beta_y > 0.0 ) {
151- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
152- }
153- }
154- }
155- if (!std::isfinite (value_of_rec (log_Qn))) {
156- return ops_partials.build (negative_infinity ());
157- }
15882 P += log_Qn;
15983
160- if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
161- const T_partials_return log_y = log (y_dbl);
162- const T_partials_return alpha_minus_one = fma (alpha_dbl, log_y, -log_y);
163-
164- const T_partials_return log_pdf = alpha_dbl * log (beta_dbl)
165- - lgamma (alpha_dbl) + alpha_minus_one
166- - beta_y;
167-
168- const T_partials_return hazard = exp (log_pdf - log_Qn); // f/Q
84+ if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
85+ const T_partials_return log_y_dbl = log (y_dbl);
86+ const T_partials_return log_beta_dbl = log (beta_dbl);
87+ const T_partials_return log_pdf
88+ = alpha_dbl * log_beta_dbl - lgamma (alpha_dbl)
89+ + (alpha_dbl - 1.0 ) * log_y_dbl - beta_y_dbl;
90+ const T_partials_return common_term = exp (log_pdf - log_Qn);
16991
17092 if constexpr (is_autodiff_v<T_y>) {
171- partials<0 >(ops_partials)[n] -= hazard;
93+ // d/dy log(1-F(y)) = -f(y)/(1-F(y))
94+ partials<0 >(ops_partials)[n] -= common_term;
17295 }
17396 if constexpr (is_autodiff_v<T_inv_scale>) {
174- partials<2 >(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
97+ // d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
98+ partials<2 >(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
17599 }
176100 }
101+
177102 if constexpr (is_autodiff_v<T_shape>) {
178- partials<1 >(ops_partials)[n] += dlogQ_dalpha;
103+ const T_partials_return digamma_val = digamma (alpha_dbl);
104+ const T_partials_return gamma_val = tgamma (alpha_dbl);
105+ // d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
106+ partials<1 >(ops_partials)[n]
107+ += grad_reg_inc_gamma (alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
108+ / Qn;
179109 }
180110 }
181111 return ops_partials.build (P);
0 commit comments