66#include < stan/math/prim/fun/constants.hpp>
77#include < stan/math/prim/fun/digamma.hpp>
88#include < stan/math/prim/fun/exp.hpp>
9- #include < stan/math/prim/fun/gamma_q.hpp>
9+ #include < stan/math/prim/fun/fma.hpp>
10+ #include < stan/math/prim/fun/gamma_p.hpp>
1011#include < stan/math/prim/fun/grad_reg_inc_gamma.hpp>
12+ #include < stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
13+ #include < stan/math/prim/fun/lgamma.hpp>
1114#include < stan/math/prim/fun/log.hpp>
15+ #include < stan/math/prim/fun/log1m.hpp>
1216#include < stan/math/prim/fun/max_size.hpp>
1317#include < stan/math/prim/fun/scalar_seq_view.hpp>
1418#include < stan/math/prim/fun/size.hpp>
1519#include < stan/math/prim/fun/size_zero.hpp>
1620#include < stan/math/prim/fun/tgamma.hpp>
17- #include < stan/math/prim/fun/value_of.hpp>
21+ #include < stan/math/prim/fun/value_of_rec.hpp>
22+ #include < stan/math/prim/fun/log_gamma_q_dgamma.hpp>
1823#include < stan/math/prim/functor/partials_propagator.hpp>
1924#include < cmath>
2025
@@ -24,10 +29,9 @@ namespace math {
2429template <typename T_y, typename T_shape, typename T_inv_scale>
2530inline return_type_t <T_y, T_shape, T_inv_scale> gamma_lccdf (
2631 const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
27- using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale>;
2832 using std::exp;
2933 using std::log;
30- using std::pow ;
34+ using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale> ;
3135 using T_y_ref = ref_type_t <T_y>;
3236 using T_alpha_ref = ref_type_t <T_shape>;
3337 using T_beta_ref = ref_type_t <T_inv_scale>;
@@ -51,61 +55,127 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
5155 scalar_seq_view<T_y_ref> y_vec (y_ref);
5256 scalar_seq_view<T_alpha_ref> alpha_vec (alpha_ref);
5357 scalar_seq_view<T_beta_ref> beta_vec (beta_ref);
54- size_t N = max_size (y, alpha, beta);
55-
56- // Explicit return for extreme values
57- // The gradients are technically ill-defined, but treated as zero
58- for (size_t i = 0 ; i < stan::math::size (y); i++) {
59- if (y_vec.val (i) == 0 ) {
60- // LCCDF(0) = log(P(Y > 0)) = log(1) = 0
61- return ops_partials.build (0.0 );
62- }
63- }
58+ const size_t N = max_size (y, alpha, beta);
59+
60+ constexpr bool any_fvar = is_fvar<scalar_type_t <T_y>>::value
61+ || is_fvar<scalar_type_t <T_shape>>::value
62+ || is_fvar<scalar_type_t <T_inv_scale>>::value;
63+ constexpr bool partials_fvar = is_fvar<T_partials_return>::value;
6464
6565 for (size_t n = 0 ; n < N; n++) {
6666 // Explicit results for extreme values
6767 // The gradients are technically ill-defined, but treated as zero
68- if (y_vec.val (n) == INFTY) {
69- // LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
68+ const T_partials_return y_dbl = y_vec.val (n);
69+ if (y_dbl == 0.0 ) {
70+ continue ;
71+ }
72+ if (y_dbl == INFTY) {
7073 return ops_partials.build (negative_infinity ());
7174 }
7275
73- const T_partials_return y_dbl = y_vec.val (n);
7476 const T_partials_return alpha_dbl = alpha_vec.val (n);
7577 const T_partials_return beta_dbl = beta_vec.val (n);
76- const T_partials_return beta_y_dbl = beta_dbl * y_dbl;
7778
78- // Qn = 1 - Pn
79- const T_partials_return Qn = gamma_q (alpha_dbl, beta_y_dbl);
80- const T_partials_return log_Qn = log (Qn);
79+ const T_partials_return beta_y = beta_dbl * y_dbl;
80+ if (beta_y == INFTY) {
81+ return ops_partials.build (negative_infinity ());
82+ }
8183
84+ bool use_cf = beta_y > alpha_dbl + 1.0 ;
85+ T_partials_return log_Qn;
86+ [[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0 ;
87+
88+ // Branch by autodiff type first, then handle use_cf logic inside each path
89+ if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
90+ // var-only path: use log_gamma_q_dgamma which computes both log_q
91+ // and its gradient analytically with double inputs
92+ const double beta_y_dbl = value_of_rec (beta_y);
93+ const double alpha_dbl_val = value_of_rec (alpha_dbl);
94+
95+ if (use_cf) {
96+ auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
97+ log_Qn = log_q_result.log_q ;
98+ dlogQ_dalpha = log_q_result.dlog_q_da ;
99+ } else {
100+ const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
101+ log_Qn = log1m (Pn);
102+ const T_partials_return Qn = exp (log_Qn);
103+
104+ // Check if we need to fallback to continued fraction
105+ bool need_cf_fallback
106+ = !std::isfinite (value_of_rec (log_Qn)) || Qn <= 0.0 ;
107+ if (need_cf_fallback && beta_y > 0.0 ) {
108+ auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
109+ log_Qn = log_q_result.log_q ;
110+ dlogQ_dalpha = log_q_result.dlog_q_da ;
111+ } else {
112+ dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha_dbl, beta_y) / Qn;
113+ }
114+ }
115+ } else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
116+ // fvar path: use unit derivative trick to compute gradients
117+ T_partials_return alpha_unit = alpha_dbl;
118+ alpha_unit.d_ = 1 ;
119+ T_partials_return beta_unit = beta_y;
120+ beta_unit.d_ = 0 ;
121+
122+ if (use_cf) {
123+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
124+ T_partials_return log_Qn_fvar
125+ = internal::log_q_gamma_cf (alpha_unit, beta_unit);
126+ dlogQ_dalpha = log_Qn_fvar.d_ ;
127+ } else {
128+ const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
129+ log_Qn = log1m (Pn);
130+
131+ if (!std::isfinite (value_of_rec (log_Qn)) && beta_y > 0.0 ) {
132+ // Fallback to continued fraction
133+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
134+ T_partials_return log_Qn_fvar
135+ = internal::log_q_gamma_cf (alpha_unit, beta_unit);
136+ dlogQ_dalpha = log_Qn_fvar.d_ ;
137+ } else {
138+ T_partials_return log_Qn_fvar = log1m (gamma_p (alpha_unit, beta_unit));
139+ dlogQ_dalpha = log_Qn_fvar.d_ ;
140+ }
141+ }
142+ } else {
143+ // No alpha derivative needed (alpha is constant or double-only)
144+ if (use_cf) {
145+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
146+ } else {
147+ const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
148+ log_Qn = log1m (Pn);
149+
150+ if (!std::isfinite (value_of_rec (log_Qn)) && beta_y > 0.0 ) {
151+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
152+ }
153+ }
154+ }
155+ if (!std::isfinite (value_of_rec (log_Qn))) {
156+ return ops_partials.build (negative_infinity ());
157+ }
82158 P += log_Qn;
83159
84- if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
85- const T_partials_return log_y_dbl = log (y_dbl);
86- const T_partials_return log_beta_dbl = log (beta_dbl);
87- const T_partials_return log_pdf
88- = alpha_dbl * log_beta_dbl - lgamma (alpha_dbl)
89- + (alpha_dbl - 1.0 ) * log_y_dbl - beta_y_dbl;
90- const T_partials_return common_term = exp (log_pdf - log_Qn);
160+ if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
161+ const T_partials_return log_y = log (y_dbl);
162+ const T_partials_return alpha_minus_one = fma (alpha_dbl, log_y, -log_y);
163+
164+ const T_partials_return log_pdf = alpha_dbl * log (beta_dbl)
165+ - lgamma (alpha_dbl) + alpha_minus_one
166+ - beta_y;
167+
168+ const T_partials_return hazard = exp (log_pdf - log_Qn); // f/Q
91169
92170 if constexpr (is_autodiff_v<T_y>) {
93- // d/dy log(1-F(y)) = -f(y)/(1-F(y))
94- partials<0 >(ops_partials)[n] -= common_term;
171+ partials<0 >(ops_partials)[n] -= hazard;
95172 }
96173 if constexpr (is_autodiff_v<T_inv_scale>) {
97- // d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
98- partials<2 >(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
174+ partials<2 >(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
99175 }
100176 }
101-
102177 if constexpr (is_autodiff_v<T_shape>) {
103- const T_partials_return digamma_val = digamma (alpha_dbl);
104- const T_partials_return gamma_val = tgamma (alpha_dbl);
105- // d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
106- partials<1 >(ops_partials)[n]
107- += grad_reg_inc_gamma (alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
108- / Qn;
178+ partials<1 >(ops_partials)[n] += dlogQ_dalpha;
109179 }
110180 }
111181 return ops_partials.build (P);
0 commit comments