Skip to content

Commit 36c77a8

Browse files
committed
fix
1 parent be7289e commit 36c77a8

4 files changed

Lines changed: 22 additions & 26 deletions

File tree

stan/math/prim/prob/multinomial_logit_log.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@ namespace math {
1212
/** \ingroup multivar_dists
1313
* @deprecated use <code>multinomial_logit_lpmf</code>
1414
*/
15-
template <bool propto, typename T_prob>
16-
return_type_t<T_prob> multinomial_logit_log(
17-
const std::vector<int>& ns,
18-
const Eigen::Matrix<T_prob, Eigen::Dynamic, 1>& theta) {
19-
return multinomial_logit_lpmf<propto, T_prob>(ns, theta);
15+
template <bool propto, typename T_beta, typename T_prob = scalar_type_t<T_beta>,
16+
require_eigen_col_vector_t<T_beta>* = nullptr>
17+
return_type_t<T_prob> multinomial_logit_log(const std::vector<int>& ns,
18+
const T_beta& beta) {
19+
return multinomial_logit_lpmf<propto, T_beta>(ns, beta);
2020
}
2121

2222
/** \ingroup multivar_dists
2323
* @deprecated use <code>multinomial_logit_lpmf</code>
2424
*/
25-
template <typename T_prob>
26-
return_type_t<T_prob> multinomial_logit_log(
27-
const std::vector<int>& ns,
28-
const Eigen::Matrix<T_prob, Eigen::Dynamic, 1>& theta) {
29-
return multinomial_logit_lpmf<false>(ns, theta);
25+
template <typename T_beta, typename T_prob = scalar_type_t<T_beta>,
26+
require_eigen_col_vector_t<T_beta>* = nullptr>
27+
return_type_t<T_prob> multinomial_logit_log(const std::vector<int>& ns,
28+
const T_beta& beta) {
29+
return multinomial_logit_lpmf<false>(ns, beta);
3030
}
3131

3232
} // namespace math

stan/math/prim/prob/multinomial_logit_lpmf.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ return_type_t<T_prob> multinomial_logit_lpmf(const std::vector<int>& ns,
4848
return lp;
4949
}
5050

51-
template <typename T_prob>
52-
return_type_t<T_prob> multinomial_logit_lpmf(
53-
const std::vector<int>& ns,
54-
const Eigen::Matrix<T_prob, Eigen::Dynamic, 1>& beta) {
51+
template <typename T_beta, typename T_prob = scalar_type_t<T_beta>,
52+
require_eigen_col_vector_t<T_beta>* = nullptr>
53+
return_type_t<T_prob> multinomial_logit_lpmf(const std::vector<int>& ns,
54+
const T_beta& beta) {
5555
return multinomial_logit_lpmf<false>(ns, beta);
5656
}
5757

test/unit/math/prim/prob/multinomial_logit_log_test.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,4 @@ TEST(ProbMultinomialLogit, log_matches_lpmf) {
1818
(multinomial_logit_log<true>(ns, theta)));
1919
EXPECT_FLOAT_EQ((multinomial_logit_lpmf<false>(ns, theta)),
2020
(multinomial_logit_log<false>(ns, theta)));
21-
EXPECT_FLOAT_EQ((multinomial_logit_lpmf<true, double>(ns, theta)),
22-
(multinomial_logit_log<true, double>(ns, theta)));
23-
EXPECT_FLOAT_EQ((multinomial_logit_lpmf<false, double>(ns, theta)),
24-
(multinomial_logit_log<false, double>(ns, theta)));
25-
EXPECT_FLOAT_EQ((multinomial_logit_lpmf<double>(ns, theta)),
26-
(multinomial_logit_log<double>(ns, theta)));
2721
}

test/unit/math/rev/prob/multinomial_logit_test.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
#include <string>
66
#include <vector>
77

8+
using stan::math::multinomial_logit_lpmf;
9+
810
template <typename T_prob>
911
void expect_propto(std::vector<int>& ns1, T_prob beta1, std::vector<int>& ns2,
1012
T_prob beta2, std::string message) {
11-
expect_eq_diffs(stan::math::multinomial_logit_lpmf<false>(ns1, beta1),
12-
stan::math::multinomial_logit_lpmf<false>(ns2, beta2),
13-
stan::math::multinomial_logit_lpmf<true>(ns1, beta1),
14-
stan::math::multinomial_logit_lpmf<true>(ns2, beta2), message);
13+
expect_eq_diffs(multinomial_logit_lpmf<false>(ns1, beta1),
14+
multinomial_logit_lpmf<false>(ns2, beta2),
15+
multinomial_logit_lpmf<true>(ns1, beta1),
16+
multinomial_logit_lpmf<true>(ns2, beta2), message);
1517
}
1618

1719
using Eigen::Dynamic;
@@ -39,6 +41,6 @@ TEST(AgradDistributionsMultinomialLogit, check_varis_on_stack) {
3941
Matrix<var, Dynamic, 1> beta(3, 1);
4042
beta << log(0.3), log(0.5), log(0.2);
4143

42-
test::check_varis_on_stack(stan::math::multinomial_logit_lpmf<false>(ns, beta));
43-
test::check_varis_on_stack(stan::math::multinomial_logit_lpmf<true>(ns, beta));
44+
test::check_varis_on_stack(multinomial_logit_lpmf<false>(ns, beta));
45+
test::check_varis_on_stack(multinomial_logit_lpmf<true>(ns, beta));
4446
}

0 commit comments

Comments
 (0)