Skip to content

Commit 6e1eb53

Browse files
committed
fix issue
1 parent 405322b commit 6e1eb53

1 file changed

Lines changed: 16 additions & 17 deletions

File tree

stan/math/prim/prob/yule_simon_rng.hpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef STAN_MATH_PRIM_PROB_YULE_SIMON_RNG_HPP
22
#define STAN_MATH_PRIM_PROB_YULE_SIMON_RNG_HPP
33

4+
#include <utility>
45
#include <stan/math/prim/meta.hpp>
56
#include <stan/math/prim/fun/exp.hpp>
67
#include <stan/math/prim/fun/log.hpp>
@@ -26,26 +27,24 @@ namespace math {
2627
* @throw std::domain_error if alpha is nonpositive
2728
*/
2829
template <typename T_alpha, typename RNG>
29-
inline auto yule_simon_rng(const T_alpha &alpha, RNG &rng) {
30-
using T_alpha_ref = ref_type_t<T_alpha>;
31-
static constexpr const char *function = "yule_simon_rng";
32-
33-
T_alpha_ref alpha_ref = alpha;
30+
inline auto yule_simon_rng(T_alpha&& alpha, RNG& rng) {
31+
static constexpr const char* function = "yule_simon_rng";
32+
decltype(auto) alpha_ref = to_ref(std::forward<T_alpha>(alpha));
3433
check_positive_finite(function, "Shape parameter", alpha_ref);
3534

36-
auto w = exponential_rng(alpha_ref, rng);
37-
scalar_seq_view<decltype(w)> w_vec(w);
38-
39-
size_t size_w = stan::math::size(w);
40-
VectorBuilder<true, int, T_alpha> output(size_w);
41-
for (size_t n = 0; n < size_w; ++n) {
42-
const double p = stan::math::exp(-w_vec[n]);
43-
const double odds_ratio_p
44-
= stan::math::exp(stan::math::log(p) - stan::math::log1m(p));
45-
output[n] = neg_binomial_rng(1.0, odds_ratio_p, rng) + 1;
35+
auto w = exponential_rng(std::forward<decltype(alpha_ref)>(alpha_ref), rng);
36+
auto w_arr = as_array_or_scalar(w);
37+
const auto p = stan::math::exp(-w_arr);
38+
const auto odds_ratio_p
39+
= stan::math::exp(stan::math::log(p) - stan::math::log1m(p));
40+
41+
if constexpr (is_stan_scalar_v<T_alpha>) {
42+
return neg_binomial_rng(1.0, odds_ratio_p, rng) + 1;
43+
} else {
44+
return to_array_1d(
45+
as_array_or_scalar(neg_binomial_rng(1.0, std::move(odds_ratio_p), rng))
46+
+ 1);
4647
}
47-
48-
return output.data();
4948
}
5049

5150
} // namespace math

0 commit comments

Comments
 (0)