Skip to content

Commit 71782ef

Browse files
committed
Fix scale_matrix_exp_multiply to use to_ref and make_holder
1 parent b863ce5 commit 71782ef

4 files changed

Lines changed: 47 additions & 36 deletions

File tree

stan/math/fwd/fun/multiply.hpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,34 @@ template <typename Mat1, typename Mat2,
1717
require_all_eigen_vt<is_fvar, Mat1, Mat2>* = nullptr,
1818
require_vt_same<Mat1, Mat2>* = nullptr,
1919
require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
20-
inline auto multiply(const Mat1& m1, const Mat2& m2) {
20+
inline auto multiply(Mat1&& m1, Mat2&& m2) {
2121
check_multiplicable("multiply", "m1", m1, "m2", m2);
22-
return to_fvar(multiply(m1.val(), m2.val()),
23-
add(multiply(m1.val(), m2.d()), multiply(m1.d(), m2.val())));
22+
decltype(auto) m1_ref = to_ref(std::forward<Mat1>(m1));
23+
decltype(auto) m2_ref = to_ref(std::forward<Mat2>(m2));
24+
return to_fvar(multiply(m1_ref.val(), m2_ref.val()),
25+
add(multiply(m1_ref.val(), m2_ref.d()), multiply(m1_ref.d(), m2_ref.val())));
2426
}
2527

2628
template <typename Mat1, typename Mat2,
2729
require_eigen_vt<is_fvar, Mat1>* = nullptr,
2830
require_eigen_vt<std::is_floating_point, Mat2>* = nullptr,
2931
require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
30-
inline auto multiply(const Mat1& m1, const Mat2& m2) {
32+
inline auto multiply(Mat1&& m1, Mat2&& m2) {
3133
check_multiplicable("multiply", "m1", m1, "m2", m2);
32-
return to_fvar(multiply(m1.val(), m2), multiply(m1.d(), m2));
34+
decltype(auto) m1_ref = to_ref(std::forward<Mat1>(m1));
35+
decltype(auto) m2_ref = to_ref(std::forward<Mat2>(m2));
36+
return to_fvar(multiply(m1_ref.val(), m2_ref), multiply(m1_ref.d(), m2_ref));
3337
}
3438

3539
template <typename Mat1, typename Mat2,
3640
require_eigen_vt<std::is_floating_point, Mat1>* = nullptr,
3741
require_eigen_vt<is_fvar, Mat2>* = nullptr,
3842
require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
39-
inline auto multiply(const Mat1& m1, const Mat2& m2) {
43+
inline auto multiply(Mat1&& m1, Mat2&& m2) {
4044
check_multiplicable("multiply", "m1", m1, "m2", m2);
41-
return to_fvar(multiply(m1, m2.val()), multiply(m1, m2.d()));
45+
decltype(auto) m1_ref = to_ref(std::forward<Mat1>(m1));
46+
decltype(auto) m2_ref = to_ref(std::forward<Mat2>(m2));
47+
return to_fvar(multiply(m1_ref, m2_ref.val()), multiply(m1_ref, m2_ref.d()));
4248
}
4349

4450
} // namespace math

stan/math/prim/fun/matrix_exp.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@ namespace math {
2121
* @throw <code>std::invalid_argument</code> if the input matrix
2222
* is not square.
2323
*/
24-
template <typename T, typename = require_eigen_t<T>>
25-
inline plain_type_t<T> matrix_exp(const T& A_in) {
26-
using std::exp;
27-
const auto& A = A_in.eval();
24+
template <typename EigenMat, typename = require_eigen_t<EigenMat>>
25+
inline plain_type_t<EigenMat> matrix_exp(EigenMat&& A_in) {
26+
decltype(auto) A = to_ref(std::forward<EigenMat>(A_in));
27+
using T = std::decay_t<EigenMat>;
2828
check_square("matrix_exp", "input matrix", A);
2929
if constexpr (T::RowsAtCompileTime == 1 && T::ColsAtCompileTime == 1) {
3030
plain_type_t<T> res;
3131
res << exp(A(0));
3232
return res;
3333
}
34-
if (A_in.size() == 0) {
34+
if (A.size() == 0) {
3535
return {};
3636
}
3737
return (A.cols() == 2

stan/math/prim/fun/matrix_exp_action_handler.hpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
66
#include <stan/math/prim/fun/ceil.hpp>
7+
#include <stan/math/prim/fun/to_ref.hpp>
78
#include <unsupported/Eigen/MatrixFunctions>
89
#include <vector>
910
#include <cmath>
@@ -23,10 +24,10 @@ namespace math {
2324
* and t is double.
2425
*/
2526
class matrix_exp_action_handler {
26-
const int _p_max = 8;
27-
const int _m_max = 55;
28-
const double _tol = 1.1e-16; // from the paper, double precision: 2^-53
29-
const std::vector<double> _theta_m{
27+
static constexpr int _p_max = 8;
28+
static constexpr int _m_max = 55;
29+
static constexpr double _tol = 1.1e-16; // from the paper, double precision: 2^-53
30+
static constexpr std::array<double, 100> _theta_m{
3031
2.22044605e-16, 2.58095680e-08, 1.38634787e-05, 3.39716884e-04,
3132
2.40087636e-03, 9.06565641e-03, 2.38445553e-02, 4.99122887e-02,
3233
8.95776020e-02, 1.44182976e-01, 2.14235807e-01, 2.99615891e-01,
@@ -64,18 +65,18 @@ class matrix_exp_action_handler {
6465
template <typename EigMat1, typename EigMat2,
6566
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
6667
require_all_st_same<double, EigMat1, EigMat2>* = nullptr>
67-
inline Eigen::MatrixXd action(const EigMat1& mat, const EigMat2& b,
68-
const double& t = 1.0) {
69-
Eigen::MatrixXd A = mat;
68+
inline Eigen::MatrixXd action(EigMat1&& mat, EigMat2&& b,
69+
const double t = 1.0) {
70+
decltype(auto) A = to_ref(std::forward<EigMat1>(mat));
7071
double mu = A.trace() / A.rows();
7172
A.diagonal().array() -= mu;
7273

7374
int m, s;
74-
set_approx_order(A, b, t, m, s);
75+
decltype(auto) b_eval = to_ref(std::forward<EigMat2>(b));
76+
set_approx_order(A, b_eval, t, m, s);
7577

7678
double eta = exp(t * mu / s);
7779

78-
const auto& b_eval = b.eval();
7980
Eigen::MatrixXd f = b_eval;
8081
Eigen::MatrixXd bi = b_eval;
8182

@@ -102,8 +103,11 @@ class matrix_exp_action_handler {
102103
*
103104
* @param x matrix
104105
*/
105-
double matrix_operator_inf_norm(Eigen::MatrixXd const& x) {
106-
return x.cwiseAbs().rowwise().sum().maxCoeff();
106+
template <typename EigenMat>
107+
double matrix_operator_inf_norm(EigenMat&& x) {
108+
return make_holder([](auto&& x_) {
109+
return x_.cwiseAbs().rowwise().sum().maxCoeff();
110+
}, std::forward<EigenMat>(x));
107111
}
108112

109113
/**
@@ -125,15 +129,16 @@ class matrix_exp_action_handler {
125129
*/
126130
template <typename EigMat1, require_all_eigen_t<EigMat1>* = nullptr,
127131
require_all_st_same<double, EigMat1>* = nullptr>
128-
double mat_power_1_norm(const EigMat1& mat, int m) {
129-
if ((mat.array() > 0.0).all()) {
130-
Eigen::VectorXd e = Eigen::VectorXd::Constant(mat.rows(), 1.0);
132+
double mat_power_1_norm(EigMat1&& mat, const int m) {
133+
auto&& mat_ref = to_ref(std::forward<EigMat1>(mat));
134+
if ((mat_ref.array() > 0.0).all()) {
135+
Eigen::VectorXd e = Eigen::VectorXd::Constant(mat_ref.rows(), 1.0);
131136
for (int j = 0; j < m; ++j) {
132-
e = mat.transpose() * e;
137+
e = mat_ref.transpose() * e;
133138
}
134139
return e.lpNorm<Eigen::Infinity>();
135140
} else {
136-
return mat.pow(m).cwiseAbs().colwise().sum().maxCoeff();
141+
return mat_ref.pow(m).cwiseAbs().colwise().sum().maxCoeff();
137142
}
138143
}
139144

@@ -156,8 +161,8 @@ class matrix_exp_action_handler {
156161
template <typename EigMat1, typename EigMat2,
157162
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
158163
require_all_st_same<double, EigMat1, EigMat2>* = nullptr>
159-
inline void set_approx_order(const EigMat1& mat, const EigMat2& b,
160-
const double& t, int& m, int& s) {
164+
inline void set_approx_order(EigMat1&& mat, EigMat2&& b,
165+
const double t, int& m, int& s) {
161166
if (t < _tol) {
162167
m = 0;
163168
s = 1;

stan/math/prim/fun/scale_matrix_exp_multiply.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ namespace math {
2626
*/
2727
template <typename EigMat1, typename EigMat2,
2828
require_all_eigen_vt<std::is_arithmetic, EigMat1, EigMat2>* = nullptr>
29-
inline Eigen::Matrix<double, Eigen::Dynamic, EigMat2::ColsAtCompileTime>
30-
scale_matrix_exp_multiply(const double& t, const EigMat1& A, const EigMat2& B) {
29+
inline Eigen::Matrix<double, Eigen::Dynamic, std::decay_t<EigMat2>::ColsAtCompileTime>
30+
scale_matrix_exp_multiply(const double t, EigMat1&& A, EigMat2&& B) {
3131
check_square("scale_matrix_exp_multiply", "input matrix", A);
3232
check_multiplicable("scale_matrix_exp_multiply", "A", A, "B", B);
3333
if (A.size() == 0) {
3434
return {0, B.cols()};
3535
}
3636

37-
return matrix_exp_action_handler().action(A, B, t);
37+
return matrix_exp_action_handler().action(std::forward<EigMat1>(A), std::forward<EigMat2>(B), t);
3838
}
3939

4040
/**
@@ -56,15 +56,15 @@ template <typename Tt, typename EigMat1, typename EigMat2,
5656
require_any_autodiff_scalar_t<Tt, value_type_t<EigMat1>,
5757
value_type_t<EigMat2>>* = nullptr>
5858
inline Eigen::Matrix<return_type_t<Tt, EigMat1, EigMat2>, Eigen::Dynamic,
59-
EigMat2::ColsAtCompileTime>
60-
scale_matrix_exp_multiply(const Tt& t, const EigMat1& A, const EigMat2& B) {
59+
std::decay_t<EigMat2>::ColsAtCompileTime>
60+
scale_matrix_exp_multiply(const Tt t, EigMat1&& A, EigMat2&& B) {
6161
check_square("scale_matrix_exp_multiply", "input matrix", A);
6262
check_multiplicable("scale_matrix_exp_multiply", "A", A, "B", B);
6363
if (A.size() == 0) {
6464
return {0, B.cols()};
6565
}
6666

67-
return multiply(matrix_exp(multiply(A, t)), B);
67+
return multiply(matrix_exp(multiply(std::forward<EigMat1>(A), t)), std::forward<EigMat2>(B));
6868
}
6969

7070
} // namespace math

0 commit comments

Comments
 (0)