|
11 | 11 | #include <stan/math/prim/fun/scalar_seq_view.hpp> |
12 | 12 | #include <stan/math/prim/fun/size.hpp> |
13 | 13 | #include <stan/math/prim/fun/size_zero.hpp> |
| 14 | +#include <stan/math/prim/fun/select.hpp> |
14 | 15 | #include <stan/math/prim/fun/value_of.hpp> |
15 | 16 | #include <stan/math/prim/functor/partials_propagator.hpp> |
16 | 17 | #include <cmath> |
@@ -51,67 +52,40 @@ inline return_type_t<T_location, T_precision> neg_binomial_2_lpmf( |
51 | 52 | T_partials_return logp(0.0); |
52 | 53 | auto ops_partials = make_partials_propagator(mu_ref, phi_ref); |
53 | 54 |
|
54 | | - scalar_seq_view<T_n_ref> n_vec(n_ref); |
55 | | - scalar_seq_view<T_mu_ref> mu_vec(mu_ref); |
56 | | - scalar_seq_view<T_phi_ref> phi_vec(phi_ref); |
57 | | - size_t size_mu = stan::math::size(mu); |
58 | | - size_t size_phi = stan::math::size(phi); |
59 | | - size_t size_mu_phi = max_size(mu, phi); |
60 | | - size_t size_n_phi = max_size(n, phi); |
61 | | - size_t size_all = max_size(n, mu, phi); |
62 | | - |
63 | | - VectorBuilder<true, T_partials_return, T_location> mu_val(size_mu); |
64 | | - for (size_t i = 0; i < size_mu; ++i) { |
65 | | - mu_val[i] = mu_vec.val(i); |
66 | | - } |
67 | | - |
68 | | - VectorBuilder<true, T_partials_return, T_precision> phi_val(size_phi); |
69 | | - VectorBuilder<true, T_partials_return, T_precision> log_phi(size_phi); |
70 | | - for (size_t i = 0; i < size_phi; ++i) { |
71 | | - phi_val[i] = phi_vec.val(i); |
72 | | - log_phi[i] = log(phi_val[i]); |
73 | | - } |
74 | | - |
75 | | - VectorBuilder<true, T_partials_return, T_location, T_precision> mu_plus_phi( |
76 | | - size_mu_phi); |
77 | | - VectorBuilder<true, T_partials_return, T_location, T_precision> |
78 | | - log_mu_plus_phi(size_mu_phi); |
79 | | - for (size_t i = 0; i < size_mu_phi; ++i) { |
80 | | - mu_plus_phi[i] = mu_val[i] + phi_val[i]; |
81 | | - log_mu_plus_phi[i] = log(mu_plus_phi[i]); |
| 55 | + auto n_vec = as_array_or_scalar(n_ref); |
| 56 | + auto mu_vec = as_array_or_scalar(mu_ref); |
| 57 | + auto phi_vec = as_array_or_scalar(phi_ref); |
| 58 | + decltype(auto) mu_val = value_of(mu_vec); |
| 59 | + decltype(auto) phi_val = value_of(phi_vec); |
| 60 | + auto log_phi = log(phi_val); |
| 61 | + auto mu_plus_phi = mu_val + phi_val; |
| 62 | + auto log_mu_plus_phi = log(mu_plus_phi); |
| 63 | + auto n_plus_phi = value_of(n_vec) + phi_val; |
| 64 | + constexpr bool include_precision = include_summand<propto, T_precision>::value; |
| 65 | + constexpr bool include_location = include_summand<propto, T_location>::value; |
| 66 | + auto logp_calc = [&]() { |
| 67 | + return -phi_val * (log1p(mu_val / phi_val)) |
| 68 | + - value_of(n_vec) * log_mu_plus_phi; |
| 69 | + }; |
| 70 | + if constexpr (include_precision || include_location) { |
| 71 | + if constexpr (include_precision && include_location) { |
| 72 | + logp += sum(binomial_coefficient_log(n_plus_phi - 1, n_vec) + multiply_log(n_vec, mu_val) + logp_calc()); |
| 73 | + } else if constexpr (include_precision) { |
| 74 | + logp += sum(binomial_coefficient_log(n_plus_phi - 1, n_vec) + logp_calc()); |
| 75 | + } else if constexpr (include_location) { |
| 76 | + logp += sum(multiply_log(n_vec, mu_val) + logp_calc()); |
| 77 | + } |
82 | 78 | } |
83 | | - |
84 | | - VectorBuilder<true, T_partials_return, T_n, T_precision> n_plus_phi( |
85 | | - size_n_phi); |
86 | | - for (size_t i = 0; i < size_n_phi; ++i) { |
87 | | - n_plus_phi[i] = n_vec[i] + phi_val[i]; |
| 79 | + if constexpr (is_autodiff_v<T_location>) { |
| 80 | + partials<0>(ops_partials) = n_vec / mu_val - (n_vec + phi_val) / mu_plus_phi; |
88 | 81 | } |
89 | | - |
90 | | - for (size_t i = 0; i < size_all; i++) { |
91 | | - if constexpr (include_summand<propto, T_precision>::value) { |
92 | | - logp += binomial_coefficient_log(n_plus_phi[i] - 1, n_vec[i]); |
93 | | - } |
94 | | - if constexpr (include_summand<propto, T_location>::value) { |
95 | | - logp += multiply_log(n_vec[i], mu_val[i]); |
96 | | - } |
97 | | - logp += -phi_val[i] * (log1p(mu_val[i] / phi_val[i])) |
98 | | - - n_vec[i] * log_mu_plus_phi[i]; |
99 | | - |
100 | | - if constexpr (is_autodiff_v<T_location>) { |
101 | | - partials<0>(ops_partials)[i] |
102 | | - += n_vec[i] / mu_val[i] - (n_vec[i] + phi_val[i]) / mu_plus_phi[i]; |
103 | | - } |
104 | | - if constexpr (is_autodiff_v<T_precision>) { |
105 | | - T_partials_return log_term; |
106 | | - if (mu_val[i] < phi_val[i]) { |
107 | | - log_term = log1p(-mu_val[i] / mu_plus_phi[i]); |
108 | | - } else { |
109 | | - log_term = log_phi[i] - log_mu_plus_phi[i]; |
110 | | - } |
111 | | - partials<1>(ops_partials)[i] += (mu_val[i] - n_vec[i]) / mu_plus_phi[i] |
112 | | - + log_term - digamma(phi_val[i]) |
113 | | - + digamma(n_plus_phi[i]); |
114 | | - } |
| 82 | + if constexpr (is_autodiff_v<T_precision>) { |
| 83 | + auto log_term |
| 84 | + = select(mu_val < phi_val, log1p(-mu_val / mu_plus_phi), |
| 85 | + log_phi - log_mu_plus_phi); |
| 86 | + partials<1>(ops_partials) = (mu_val - value_of(n_vec)) / mu_plus_phi |
| 87 | + + log_term - digamma(phi_val) |
| 88 | + + digamma(n_plus_phi); |
115 | 89 | } |
116 | 90 | return ops_partials.build(logp); |
117 | 91 | } |
|
0 commit comments