|
| 1 | +#ifndef STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP |
| 2 | +#define STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP |
| 3 | + |
| 4 | +#include <stan/math/prim/meta.hpp> |
| 5 | +#include <stan/math/prim/fun/constants.hpp> |
| 6 | +#include <stan/math/prim/fun/digamma.hpp> |
| 7 | +#include <stan/math/prim/fun/exp.hpp> |
| 8 | +#include <stan/math/prim/fun/fabs.hpp> |
| 9 | +#include <stan/math/prim/fun/gamma_p.hpp> |
| 10 | +#include <stan/math/prim/fun/gamma_q.hpp> |
| 11 | +#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp> |
| 12 | +#include <stan/math/prim/fun/inv.hpp> |
| 13 | +#include <stan/math/prim/fun/lgamma.hpp> |
| 14 | +#include <stan/math/prim/fun/log.hpp> |
| 15 | +#include <stan/math/prim/fun/log1m.hpp> |
| 16 | +#include <stan/math/prim/fun/tgamma.hpp> |
| 17 | +#include <stan/math/prim/fun/value_of.hpp> |
| 18 | +#include <stan/math/prim/fun/value_of_rec.hpp> |
| 19 | +#include <cmath> |
| 20 | + |
| 21 | +namespace stan { |
| 22 | +namespace math { |
| 23 | + |
| 24 | +namespace internal { |
| 25 | + |
| 26 | +constexpr double LOG_Q_GAMMA_CF_PRECISION = 1.49012e-12; |
| 27 | + |
| 28 | +/** |
| 29 | + * Compute log(Q(a,z)) using continued fraction expansion for upper incomplete |
| 30 | + * gamma function. |
| 31 | + * |
| 32 | + * @tparam T_a Type of shape parameter a (double or fvar types) |
| 33 | + * @tparam T_z Type of value parameter z (double or fvar types) |
| 34 | + * @param a Shape parameter |
| 35 | + * @param z Value at which to evaluate |
| 36 | + * @param precision Convergence threshold, default of sqrt(machine_epsilon) |
| 37 | + * @param max_steps Maximum number of continued fraction iterations |
| 38 | + * @return log(Q(a,z)) with the return type of T_a and T_z |
| 39 | + */ |
| 40 | +template <typename T_a, typename T_z> |
| 41 | +inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z, |
| 42 | + double precision |
| 43 | + = LOG_Q_GAMMA_CF_PRECISION, |
| 44 | + int max_steps = 250) { |
| 45 | + using T_return = return_type_t<T_a, T_z>; |
| 46 | + const T_return log_prefactor = a * log(z) - z - lgamma(a); |
| 47 | + |
| 48 | + T_return b_init = z + 1.0 - a; |
| 49 | + T_return C = (fabs(value_of_rec(b_init)) >= EPSILON) |
| 50 | + ? b_init |
| 51 | + : std::decay_t<decltype(b_init)>(EPSILON); |
| 52 | + T_return D = 0.0; |
| 53 | + T_return f = C; |
| 54 | + for (int i = 1; i <= max_steps; ++i) { |
| 55 | + T_a an = -i * (i - a); |
| 56 | + const T_return b = b_init + 2.0 * i; |
| 57 | + D = b + an * D; |
| 58 | + D = (fabs(value_of_rec(D)) >= EPSILON) ? D |
| 59 | + : std::decay_t<decltype(D)>(EPSILON); |
| 60 | + C = b + an / C; |
| 61 | + C = (fabs(value_of_rec(C)) >= EPSILON) ? C |
| 62 | + : std::decay_t<decltype(C)>(EPSILON); |
| 63 | + D = inv(D); |
| 64 | + const T_return delta = C * D; |
| 65 | + f *= delta; |
| 66 | + const double delta_m1 = fabs(value_of_rec(delta) - 1.0); |
| 67 | + if (delta_m1 < precision) { |
| 68 | + break; |
| 69 | + } |
| 70 | + } |
| 71 | + return log_prefactor - log(f); |
| 72 | +} |
| 73 | + |
| 74 | +} // namespace internal |
| 75 | + |
| 76 | +/** |
| 77 | + * Compute log(Q(a,z)) and its gradient with respect to a using continued |
| 78 | + * fraction expansion, where Q(a,z) = Gamma(a,z) / Gamma(a) is the regularized |
| 79 | + * upper incomplete gamma function. |
| 80 | + * |
| 81 | + * This uses a continued fraction representation for numerical stability when |
| 82 | + * computing the upper incomplete gamma function in log space, along with |
| 83 | + * analytical gradient computation. |
| 84 | + * |
| 85 | + * @tparam T_a type of the shape parameter |
| 86 | + * @tparam T_z type of the value parameter |
| 87 | + * @param a shape parameter (must be positive) |
| 88 | + * @param z value parameter (must be non-negative) |
| 89 | + * @param precision convergence threshold, default of sqrt(machine_epsilon) |
| 90 | + * @param max_steps maximum iterations for continued fraction |
| 91 | + * @return structure containing log(Q(a,z)) and d/da log(Q(a,z)) |
| 92 | + */ |
| 93 | +template <typename T_a, typename T_z> |
| 94 | +inline std::pair<return_type_t<T_a, T_z>, return_type_t<T_a, T_z>> |
| 95 | +log_gamma_q_dgamma(const T_a& a, const T_z& z, |
| 96 | + double precision = internal::LOG_Q_GAMMA_CF_PRECISION, |
| 97 | + int max_steps = 250) { |
| 98 | + using T_return = return_type_t<T_a, T_z>; |
| 99 | + const double a_val = value_of(a); |
| 100 | + const double z_val = value_of(z); |
| 101 | + // For z > a + 1, use continued fraction for better numerical stability |
| 102 | + if (z_val > a_val + 1.0) { |
| 103 | + std::pair<T_return, T_return> result{ |
| 104 | + internal::log_q_gamma_cf(a_val, z_val, precision, max_steps), 0.0}; |
| 105 | + // For gradient, use: d/da log(Q) = (1/Q) * dQ/da |
| 106 | + // grad_reg_inc_gamma computes dQ/da |
| 107 | + const T_return Q_val = exp(result.first); |
| 108 | + const double dQ_da |
| 109 | + = grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val)); |
| 110 | + result.second = dQ_da / Q_val; |
| 111 | + return result; |
| 112 | + } else { |
| 113 | + // For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy |
| 114 | + const double P_val = gamma_p(a_val, z_val); |
| 115 | + std::pair<T_return, T_return> result{log1m(P_val), 0.0}; |
| 116 | + // Gradient: d/da log(Q) = (1/Q) * dQ/da |
| 117 | + // grad_reg_inc_gamma computes dQ/da |
| 118 | + const T_return Q_val = exp(result.first); |
| 119 | + if (Q_val > 0) { |
| 120 | + const double dQ_da |
| 121 | + = grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val)); |
| 122 | + result.second = dQ_da / Q_val; |
| 123 | + } else { |
| 124 | + // Fallback if Q rounds to zero - use asymptotic approximation |
| 125 | + result.second = log(z_val) - digamma(a_val); |
| 126 | + } |
| 127 | + return result; |
| 128 | + } |
| 129 | +} |
| 130 | + |
| 131 | +} // namespace math |
| 132 | +} // namespace stan |
| 133 | + |
| 134 | +#endif |
0 commit comments