Skip to content

Commit 93b1a13

Browse files
committed
Merge commit 'd48a48744910f99cd94abf0a5789ad6a8031c758' into HEAD
2 parents 690fcea + d48a487 commit 93b1a13

33 files changed

Lines changed: 16361 additions & 144 deletions

stan/math/fwd/fun/determinant.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/Eigen.hpp>
7+
#include <stan/math/prim/fun/to_ref.hpp>
78
#include <stan/math/fwd/core.hpp>
89

910
namespace stan {
@@ -12,9 +13,10 @@ namespace math {
1213
template <typename EigMat, require_eigen_vt<is_fvar, EigMat>* = nullptr>
1314
inline value_type_t<EigMat> determinant(const EigMat& m) {
1415
check_square("determinant", "m", m);
16+
const auto& m_ref = to_ref(m);
1517

16-
const typename value_type_t<EigMat>::Scalar vals = m.val().determinant();
17-
return {vals, vals * (m.val().inverse() * m.d()).trace()};
18+
const typename value_type_t<EigMat>::Scalar vals = m_ref.val().determinant();
19+
return {vals, vals * (m_ref.val().inverse() * m_ref.d()).trace()};
1820
}
1921

2022
} // namespace math

stan/math/fwd/fun/log_sum_exp.hpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/prim/fun/Eigen.hpp>
88
#include <stan/math/prim/fun/constants.hpp>
99
#include <stan/math/prim/fun/log_sum_exp.hpp>
10+
#include <stan/math/prim/fun/to_ref.hpp>
1011
#include <cmath>
1112
#include <vector>
1213

@@ -52,15 +53,17 @@ inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
5253
*/
5354
template <typename T, require_container_st<is_fvar, T>* = nullptr>
5455
inline auto log_sum_exp(const T& x) {
55-
return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
56-
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
57-
using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
58-
mat_type vals = v.val();
59-
mat_type exp_vals = vals.array().exp();
56+
return apply_vector_unary<ref_type_t<T>>::reduce(
57+
to_ref(x), [&](const auto& v) {
58+
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
59+
using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
60+
mat_type vals = v.val();
61+
mat_type exp_vals = vals.array().exp();
6062

61-
return fvar<T_fvar_inner>(
62-
log_sum_exp(vals), v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
63-
});
63+
return fvar<T_fvar_inner>(
64+
log_sum_exp(vals),
65+
v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
66+
});
6467
}
6568

6669
} // namespace math

stan/math/fwd/fun/quad_form.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/fwd/core.hpp>
55
#include <stan/math/fwd/fun/multiply.hpp>
66
#include <stan/math/prim/fun/dot_product.hpp>
7+
#include <stan/math/prim/fun/to_ref.hpp>
78

89
namespace stan {
910
namespace math {
@@ -31,7 +32,8 @@ inline promote_scalar_t<return_type_t<EigMat1, EigMat2>, EigMat2> quad_form(
3132
const EigMat1& A, const EigMat2& B) {
3233
check_square("quad_form", "A", A);
3334
check_multiplicable("quad_form", "A", A, "B", B);
34-
return multiply(B.transpose(), multiply(A, B));
35+
const auto& B_ref = to_ref(B);
36+
return multiply(B_ref.transpose(), multiply(A, B_ref));
3537
}
3638

3739
/**
@@ -53,7 +55,8 @@ inline return_type_t<EigMat, ColVec> quad_form(const EigMat& A,
5355
const ColVec& B) {
5456
check_square("quad_form", "A", A);
5557
check_multiplicable("quad_form", "A", A, "B", B);
56-
return dot_product(B, multiply(A, B));
58+
const auto& B_ref = to_ref(B);
59+
return dot_product(B_ref, multiply(A, B_ref));
5760
}
5861

5962
} // namespace math

stan/math/fwd/fun/quad_form_sym.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/fwd/core.hpp>
55
#include <stan/math/fwd/fun/multiply.hpp>
66
#include <stan/math/prim/fun/dot_product.hpp>
7+
#include <stan/math/prim/fun/to_ref.hpp>
78

89
namespace stan {
910
namespace math {
@@ -31,7 +32,9 @@ inline promote_scalar_t<return_type_t<EigMat1, EigMat2>, EigMat2> quad_form_sym(
3132
using T_ret = return_type_t<EigMat1, EigMat2>;
3233
check_multiplicable("quad_form_sym", "A", A, "B", B);
3334
check_symmetric("quad_form_sym", "A", A);
34-
promote_scalar_t<T_ret, EigMat2> ret(multiply(B.transpose(), multiply(A, B)));
35+
const auto& B_ref = to_ref(B);
36+
promote_scalar_t<T_ret, EigMat2> ret(
37+
multiply(B_ref.transpose(), multiply(A, B_ref)));
3538
return T_ret(0.5) * (ret + ret.transpose());
3639
}
3740

@@ -54,7 +57,8 @@ inline return_type_t<EigMat, ColVec> quad_form_sym(const EigMat& A,
5457
const ColVec& B) {
5558
check_multiplicable("quad_form_sym", "A", A, "B", B);
5659
check_symmetric("quad_form_sym", "A", A);
57-
return dot_product(B, multiply(A, B));
60+
const auto& B_ref = to_ref(B);
61+
return dot_product(B_ref, multiply(A, B_ref));
5862
}
5963

6064
} // namespace math

stan/math/fwd/fun/softmax.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
#define STAN_MATH_FWD_FUN_SOFTMAX_HPP
33

44
#include <stan/math/fwd/core.hpp>
5+
#include <stan/math/fwd/fun/value_of.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
67
#include <stan/math/prim/fun/softmax.hpp>
8+
#include <stan/math/prim/fun/to_ref.hpp>
9+
#include <stan/math/prim/fun/value_of.hpp>
710

811
namespace stan {
912
namespace math {
@@ -14,13 +17,9 @@ inline auto softmax(const ColVec& alpha) {
1417
using Eigen::Dynamic;
1518
using Eigen::Matrix;
1619
using T = typename value_type_t<ColVec>::Scalar;
20+
const auto& alpha_ref = to_ref(alpha);
1721

18-
Matrix<T, Dynamic, 1> alpha_t(alpha.size());
19-
for (int k = 0; k < alpha.size(); ++k) {
20-
alpha_t.coeffRef(k) = alpha.coeff(k).val_;
21-
}
22-
23-
Matrix<T, Dynamic, 1> softmax_alpha_t = softmax(alpha_t);
22+
Matrix<T, Dynamic, 1> softmax_alpha_t = softmax(value_of(alpha_ref));
2423

2524
Matrix<fvar<T>, Dynamic, 1> softmax_alpha(alpha.size());
2625
for (int k = 0; k < alpha.size(); ++k) {
@@ -30,12 +29,12 @@ inline auto softmax(const ColVec& alpha) {
3029

3130
for (int m = 0; m < alpha.size(); ++m) {
3231
T negative_alpha_m_d_times_softmax_alpha_t_m
33-
= -alpha.coeff(m).d_ * softmax_alpha_t.coeff(m);
32+
= -alpha_ref.coeff(m).d_ * softmax_alpha_t.coeff(m);
3433
for (int k = 0; k < alpha.size(); ++k) {
3534
if (m == k) {
3635
softmax_alpha.coeffRef(k).d_
3736
+= softmax_alpha_t.coeff(k)
38-
* (alpha.coeff(m).d_
37+
* (alpha_ref.coeff(m).d_
3938
+ negative_alpha_m_d_times_softmax_alpha_t_m);
4039
} else {
4140
softmax_alpha.coeffRef(k).d_

stan/math/fwd/fun/tcrossprod.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
6+
#include <stan/math/prim/fun/to_ref.hpp>
67
#include <stan/math/prim/fun/transpose.hpp>
78
#include <stan/math/fwd/fun/multiply.hpp>
89

@@ -16,7 +17,8 @@ tcrossprod(const EigMat& m) {
1617
if (m.rows() == 0) {
1718
return {};
1819
}
19-
return multiply(m, m.transpose());
20+
const auto& m_ref = to_ref(m);
21+
return multiply(m_ref, m_ref.transpose());
2022
}
2123

2224
} // namespace math

stan/math/fwd/fun/trace_quad_form.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/prim/fun/multiply.hpp>
77
#include <stan/math/prim/fun/transpose.hpp>
88
#include <stan/math/prim/fun/trace.hpp>
9+
#include <stan/math/prim/fun/to_ref.hpp>
910
#include <stan/math/fwd/core.hpp>
1011

1112
namespace stan {
@@ -18,7 +19,8 @@ inline return_type_t<EigMat1, EigMat2> trace_quad_form(const EigMat1& A,
1819
const EigMat2& B) {
1920
check_square("trace_quad_form", "A", A);
2021
check_multiplicable("trace_quad_form", "A", A, "B", B);
21-
return B.cwiseProduct(multiply(A, B)).sum();
22+
const auto& B_ref = to_ref(B);
23+
return B_ref.cwiseProduct(multiply(A, B_ref)).sum();
2224
}
2325

2426
} // namespace math

stan/math/prim/err/check_symmetric.hpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,26 @@ namespace math {
1919
* Check if the specified matrix is symmetric.
2020
* The error message is either 0 or 1 indexed, specified by
2121
* <code>stan::error_index::value</code>.
22-
* @tparam T_y Type of scalar
22+
* @tparam EigMat Type of matrix
2323
* @param function Function name (for error messages)
2424
* @param name Variable name (for error messages)
2525
* @param y Matrix to test
2626
* @throw <code>std::invalid_argument</code> if the matrix is not square.
2727
* @throw <code>std::domain_error</code> if any element not on the
2828
* main diagonal is <code>NaN</code>
2929
*/
30-
template <typename T_y>
31-
inline void check_symmetric(
32-
const char* function, const char* name,
33-
const Eigen::Matrix<T_y, Eigen::Dynamic, Eigen::Dynamic>& y) {
30+
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
31+
inline void check_symmetric(const char* function, const char* name,
32+
const EigMat& y) {
3433
check_square(function, name, y);
3534
using std::fabs;
36-
using size_type
37-
= index_type_t<Eigen::Matrix<T_y, Eigen::Dynamic, Eigen::Dynamic>>;
3835

39-
size_type k = y.rows();
36+
Eigen::Index k = y.rows();
4037
if (k <= 1) {
4138
return;
4239
}
43-
for (size_type m = 0; m < k; ++m) {
44-
for (size_type n = m + 1; n < k; ++n) {
40+
for (Eigen::Index m = 0; m < k; ++m) {
41+
for (Eigen::Index n = m + 1; n < k; ++n) {
4542
if (!(fabs(value_of(y(m, n)) - value_of(y(n, m)))
4643
<= CONSTRAINT_TOLERANCE)) {
4744
std::ostringstream msg1;

stan/math/prim/fun/cholesky_decompose.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
3535
inline Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
3636
EigMat::ColsAtCompileTime>
3737
cholesky_decompose(const EigMat& m) {
38-
eval_return_type_t<EigMat>& m_eval = m.eval();
38+
const eval_return_type_t<EigMat>& m_eval = m.eval();
3939
check_symmetric("cholesky_decompose", "m", m_eval);
4040
check_not_nan("cholesky_decompose", "m", m_eval);
4141
Eigen::LLT<Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
@@ -66,7 +66,7 @@ template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
6666
inline Eigen::Matrix<double, EigMat::RowsAtCompileTime,
6767
EigMat::ColsAtCompileTime>
6868
cholesky_decompose(const EigMat& m) {
69-
eval_return_type_t<EigMat>& m_eval = m.eval();
69+
const eval_return_type_t<EigMat>& m_eval = m.eval();
7070
check_not_nan("cholesky_decompose", "m", m_eval);
7171
#ifdef STAN_OPENCL
7272
if (m.rows() >= opencl_context.tuning_opts().cholesky_size_worth_transfer) {

stan/math/prim/fun/inverse_spd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ inverse_spd(const EigMat& m) {
2525
using Eigen::LDLT;
2626
using Eigen::Matrix;
2727
using Scalar = value_type_t<EigMat>;
28-
check_symmetric("inverse_spd", "m", m);
2928
if (m.size() == 0) {
3029
return {};
3130
}
3231
const Eigen::Ref<const plain_type_t<EigMat>>& m_ref = m;
32+
check_symmetric("inverse_spd", "m", m_ref);
3333
plain_type_t<EigMat> mmt = 0.5 * (m_ref + m_ref.transpose());
3434
LDLT<plain_type_t<EigMat>> ldlt(mmt);
3535
if (ldlt.info() != Eigen::Success) {

0 commit comments

Comments
 (0)