Skip to content

Commit 1722ad9

Browse files
committed
update neg_binomial_2_lpmf to be more vectorization-friendly
1 parent 18a9cbf commit 1722ad9

2 files changed

Lines changed: 52 additions & 68 deletions

File tree

stan/math/prim/fun/select.hpp

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ template <typename T_true, typename T_false,
112112
plain_type_t<T_false>>,
113113
require_stan_scalar_t<T_true>* = nullptr,
114114
require_container_t<T_false>* = nullptr>
115-
inline ReturnT select(const bool c, const T_true y_true, T_false&& y_false) {
115+
inline ReturnT select(const bool c, const T_true& y_true, T_false&& y_false) {
116116
if (c) {
117117
return apply_scalar_binary(
118118
[](auto&& y_true_inner, auto&& y_false_inner) { return y_true_inner; },
@@ -140,13 +140,13 @@ inline ReturnT select(const bool c, const T_true y_true, T_false&& y_false) {
140140
template <typename T_bool, typename T_true, typename T_false,
141141
require_eigen_array_vt<std::is_integral, T_bool>* = nullptr,
142142
require_all_stan_scalar_t<T_true, T_false>* = nullptr>
143-
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
143+
inline auto select(T_bool&& c, const T_true& y_true, const T_false& y_false) {
144144
using ret_t = return_type_t<T_true, T_false>;
145-
return c
146-
.unaryExpr([y_true, y_false](bool cond) {
145+
return make_holder([y_true, y_false](auto&& c_) {
146+
return std::forward<decltype(c_)>(c_).unaryExpr([y_true, y_false](bool cond) {
147147
return cond ? ret_t(y_true) : ret_t(y_false);
148-
})
149-
.eval();
148+
});
149+
}, std::forward<T_bool>(c));
150150
}
151151

152152
/**
@@ -164,13 +164,23 @@ inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
164164
template <typename T_bool, typename T_true, typename T_false,
165165
require_eigen_array_t<T_bool>* = nullptr,
166166
require_any_eigen_array_t<T_true, T_false>* = nullptr>
167-
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
167+
inline auto select(T_bool&& c, T_true&& y_true, T_false&& y_false) {
168168
check_consistent_sizes("select", "boolean", c, "left hand side", y_true,
169169
"right hand side", y_false);
170170
using ret_t = return_type_t<T_true, T_false>;
171-
return c.select(y_true, y_false).template cast<ret_t>().eval();
171+
if constexpr (!std::is_same_v<std::decay_t<T_true>, std::decay_t<T_false>>) {
172+
return make_holder([](auto&& c_, auto&& y_true_, auto&& y_false_) {
173+
return std::forward<decltype(c_)>(c_).select(
174+
std::forward<decltype(y_true_)>(y_true_),
175+
std::forward<decltype(y_false_)>(y_false_));
176+
}, std::forward<T_bool>(c), std::forward<T_true>(y_true), std::forward<T_false>(y_false));
177+
} else {
178+
return make_holder([](auto&& c_, auto&& y_true_, auto&& y_false_) {
179+
return std::forward<decltype(c_)>(c_).select(std::forward<decltype(y_true_)>(y_true_),
180+
std::forward<decltype(y_false_)>(y_false_)).template cast<ret_t>();
181+
}, std::forward<T_bool>(c), std::forward<T_true>(y_true), std::forward<T_false>(y_false));
182+
}
172183
}
173-
174184
} // namespace math
175185
} // namespace stan
176186

stan/math/prim/prob/neg_binomial_2_lpmf.hpp

Lines changed: 33 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1212
#include <stan/math/prim/fun/size.hpp>
1313
#include <stan/math/prim/fun/size_zero.hpp>
14+
#include <stan/math/prim/fun/select.hpp>
1415
#include <stan/math/prim/fun/value_of.hpp>
1516
#include <stan/math/prim/functor/partials_propagator.hpp>
1617
#include <cmath>
@@ -51,67 +52,40 @@ inline return_type_t<T_location, T_precision> neg_binomial_2_lpmf(
5152
T_partials_return logp(0.0);
5253
auto ops_partials = make_partials_propagator(mu_ref, phi_ref);
5354

54-
scalar_seq_view<T_n_ref> n_vec(n_ref);
55-
scalar_seq_view<T_mu_ref> mu_vec(mu_ref);
56-
scalar_seq_view<T_phi_ref> phi_vec(phi_ref);
57-
size_t size_mu = stan::math::size(mu);
58-
size_t size_phi = stan::math::size(phi);
59-
size_t size_mu_phi = max_size(mu, phi);
60-
size_t size_n_phi = max_size(n, phi);
61-
size_t size_all = max_size(n, mu, phi);
62-
63-
VectorBuilder<true, T_partials_return, T_location> mu_val(size_mu);
64-
for (size_t i = 0; i < size_mu; ++i) {
65-
mu_val[i] = mu_vec.val(i);
66-
}
67-
68-
VectorBuilder<true, T_partials_return, T_precision> phi_val(size_phi);
69-
VectorBuilder<true, T_partials_return, T_precision> log_phi(size_phi);
70-
for (size_t i = 0; i < size_phi; ++i) {
71-
phi_val[i] = phi_vec.val(i);
72-
log_phi[i] = log(phi_val[i]);
73-
}
74-
75-
VectorBuilder<true, T_partials_return, T_location, T_precision> mu_plus_phi(
76-
size_mu_phi);
77-
VectorBuilder<true, T_partials_return, T_location, T_precision>
78-
log_mu_plus_phi(size_mu_phi);
79-
for (size_t i = 0; i < size_mu_phi; ++i) {
80-
mu_plus_phi[i] = mu_val[i] + phi_val[i];
81-
log_mu_plus_phi[i] = log(mu_plus_phi[i]);
55+
auto n_vec = as_array_or_scalar(n_ref);
56+
auto mu_vec = as_array_or_scalar(mu_ref);
57+
auto phi_vec = as_array_or_scalar(phi_ref);
58+
decltype(auto) mu_val = value_of(mu_vec);
59+
decltype(auto) phi_val = value_of(phi_vec);
60+
auto log_phi = log(phi_val);
61+
auto mu_plus_phi = mu_val + phi_val;
62+
auto log_mu_plus_phi = log(mu_plus_phi);
63+
auto n_plus_phi = value_of(n_vec) + phi_val;
64+
constexpr bool include_precision = include_summand<propto, T_precision>::value;
65+
constexpr bool include_location = include_summand<propto, T_location>::value;
66+
auto logp_calc = [&]() {
67+
return -phi_val * (log1p(mu_val / phi_val))
68+
- value_of(n_vec) * log_mu_plus_phi;
69+
};
70+
if constexpr (include_precision || include_location) {
71+
if constexpr (include_precision && include_location) {
72+
logp += sum(binomial_coefficient_log(n_plus_phi - 1, n_vec) + multiply_log(n_vec, mu_val) + logp_calc());
73+
} else if constexpr (include_precision) {
74+
logp += sum(binomial_coefficient_log(n_plus_phi - 1, n_vec) + logp_calc());
75+
} else if constexpr (include_location) {
76+
logp += sum(multiply_log(n_vec, mu_val) + logp_calc());
77+
}
8278
}
83-
84-
VectorBuilder<true, T_partials_return, T_n, T_precision> n_plus_phi(
85-
size_n_phi);
86-
for (size_t i = 0; i < size_n_phi; ++i) {
87-
n_plus_phi[i] = n_vec[i] + phi_val[i];
79+
if constexpr (is_autodiff_v<T_location>) {
80+
partials<0>(ops_partials) = n_vec / mu_val - (n_vec + phi_val) / mu_plus_phi;
8881
}
89-
90-
for (size_t i = 0; i < size_all; i++) {
91-
if constexpr (include_summand<propto, T_precision>::value) {
92-
logp += binomial_coefficient_log(n_plus_phi[i] - 1, n_vec[i]);
93-
}
94-
if constexpr (include_summand<propto, T_location>::value) {
95-
logp += multiply_log(n_vec[i], mu_val[i]);
96-
}
97-
logp += -phi_val[i] * (log1p(mu_val[i] / phi_val[i]))
98-
- n_vec[i] * log_mu_plus_phi[i];
99-
100-
if constexpr (is_autodiff_v<T_location>) {
101-
partials<0>(ops_partials)[i]
102-
+= n_vec[i] / mu_val[i] - (n_vec[i] + phi_val[i]) / mu_plus_phi[i];
103-
}
104-
if constexpr (is_autodiff_v<T_precision>) {
105-
T_partials_return log_term;
106-
if (mu_val[i] < phi_val[i]) {
107-
log_term = log1p(-mu_val[i] / mu_plus_phi[i]);
108-
} else {
109-
log_term = log_phi[i] - log_mu_plus_phi[i];
110-
}
111-
partials<1>(ops_partials)[i] += (mu_val[i] - n_vec[i]) / mu_plus_phi[i]
112-
+ log_term - digamma(phi_val[i])
113-
+ digamma(n_plus_phi[i]);
114-
}
82+
if constexpr (is_autodiff_v<T_precision>) {
83+
auto log_term
84+
= select(mu_val < phi_val, log1p(-mu_val / mu_plus_phi),
85+
log_phi - log_mu_plus_phi);
86+
partials<1>(ops_partials) = (mu_val - value_of(n_vec)) / mu_plus_phi
87+
+ log_term - digamma(phi_val)
88+
+ digamma(n_plus_phi);
11589
}
11690
return ops_partials.build(logp);
11791
}

0 commit comments

Comments (0)