|
9 | 9 | #include <stan/math/fwd/fun/multiply.hpp> |
10 | 10 | #include <stan/math/prim/err.hpp> |
11 | 11 | #include <stan/math/prim/fun/mdivide_left_tri_low.hpp> |
12 | | - |
| 12 | +#include <stan/math/prim/fun/eval.hpp> |
| 13 | +#include <stan/math/prim/fun/subtract.hpp> |
13 | 14 | namespace stan { |
14 | 15 | namespace math { |
15 | 16 |
|
16 | 17 | template <typename T1, typename T2, |
17 | 18 | require_all_eigen_vt<is_fvar, T1, T2>* = nullptr, |
18 | 19 | require_vt_same<T1, T2>* = nullptr> |
19 | | -inline Eigen::Matrix<value_type_t<T1>, T1::RowsAtCompileTime, |
20 | | - T2::ColsAtCompileTime> |
21 | | -mdivide_left_tri_low(const T1& A, const T2& b) { |
22 | | - using T = typename value_type_t<T1>::Scalar; |
23 | | - constexpr int S1 = T1::RowsAtCompileTime; |
24 | | - constexpr int C2 = T2::ColsAtCompileTime; |
| 20 | +inline Eigen::Matrix<value_type_t<T1>, std::decay_t<T1>::RowsAtCompileTime, |
| 21 | + std::decay_t<T2>::ColsAtCompileTime> |
| 22 | +mdivide_left_tri_low(T1&& A, T2&& b) { |
| 23 | + constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime; |
| 24 | + constexpr int C2 = std::decay_t<T2>::ColsAtCompileTime; |
25 | 25 |
|
26 | 26 | check_square("mdivide_left_tri_low", "A", A); |
27 | 27 | check_multiplicable("mdivide_left_tri_low", "A", A, "b", b); |
28 | 28 | if (A.size() == 0) { |
29 | 29 | return {0, b.cols()}; |
30 | 30 | } |
31 | | - |
32 | | - Eigen::Matrix<T, S1, S1> val_A(A.rows(), A.cols()); |
33 | | - Eigen::Matrix<T, S1, S1> deriv_A(A.rows(), A.cols()); |
34 | | - val_A.setZero(); |
35 | | - deriv_A.setZero(); |
36 | | - |
37 | | - const Eigen::Ref<const plain_type_t<T2>>& b_ref = b; |
38 | | - const Eigen::Ref<const plain_type_t<T1>>& A_ref = A; |
39 | | - for (size_type j = 0; j < A.cols(); j++) { |
40 | | - for (size_type i = j; i < A.rows(); i++) { |
41 | | - val_A(i, j) = A_ref(i, j).val_; |
42 | | - deriv_A(i, j) = A_ref(i, j).d_; |
43 | | - } |
44 | | - } |
45 | | - |
46 | | - Eigen::Matrix<T, S1, C2> inv_A_mult_b = mdivide_left(val_A, b_ref.val()); |
47 | | - |
48 | | - return to_fvar(inv_A_mult_b, |
49 | | - mdivide_left(val_A, b_ref.d()) |
50 | | - - multiply(mdivide_left(val_A, deriv_A), inv_A_mult_b)); |
| 31 | + decltype(auto) b_ref = to_ref(std::forward<T2>(b)); |
| 32 | + decltype(auto) A_ref = to_ref(std::forward<T1>(A)); |
| 33 | + auto inv_A_mult_b |
| 34 | + = eval(mdivide_left_tri<Eigen::Lower>(A_ref.val(), b_ref.val())); |
| 35 | + return to_fvar( |
| 36 | + inv_A_mult_b, |
| 37 | + subtract(mdivide_left_tri<Eigen::Lower>(A_ref.val(), b_ref.d()), |
| 38 | + multiply(mdivide_left_tri<Eigen::Lower>( |
| 39 | + A_ref.val(), |
| 40 | + A_ref.d().template triangularView<Eigen::Lower>()), |
| 41 | + inv_A_mult_b))); |
51 | 42 | } |
52 | 43 |
|
53 | 44 | template <typename T1, typename T2, require_eigen_t<T1>* = nullptr, |
54 | 45 | require_vt_same<double, T1>* = nullptr, |
55 | 46 | require_eigen_vt<is_fvar, T2>* = nullptr> |
56 | | -inline Eigen::Matrix<value_type_t<T2>, T1::RowsAtCompileTime, |
57 | | - T2::ColsAtCompileTime> |
58 | | -mdivide_left_tri_low(const T1& A, const T2& b) { |
59 | | - constexpr int S1 = T1::RowsAtCompileTime; |
60 | | - |
| 47 | +inline Eigen::Matrix<value_type_t<T2>, std::decay_t<T1>::RowsAtCompileTime, |
| 48 | + std::decay_t<T2>::ColsAtCompileTime> |
| 49 | +mdivide_left_tri_low(T1&& A, T2&& b) { |
| 50 | + constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime; |
61 | 51 | check_square("mdivide_left_tri_low", "A", A); |
62 | 52 | check_multiplicable("mdivide_left_tri_low", "A", A, "b", b); |
63 | 53 | if (A.size() == 0) { |
64 | 54 | return {0, b.cols()}; |
65 | 55 | } |
66 | | - |
67 | | - Eigen::Matrix<double, S1, S1> val_A(A.rows(), A.cols()); |
68 | | - val_A.setZero(); |
69 | | - |
70 | | - const Eigen::Ref<const plain_type_t<T2>>& b_ref = b; |
71 | | - const Eigen::Ref<const plain_type_t<T1>>& A_ref = A; |
72 | | - for (size_type j = 0; j < A.cols(); j++) { |
73 | | - for (size_type i = j; i < A.rows(); i++) { |
74 | | - val_A(i, j) = A_ref(i, j); |
75 | | - } |
76 | | - } |
77 | | - |
78 | | - return to_fvar(mdivide_left(val_A, b_ref.val()), |
79 | | - mdivide_left(val_A, b_ref.d())); |
| 56 | + decltype(auto) A_ref = to_ref(std::forward<T1>(A)); |
| 57 | + decltype(auto) b_ref = to_ref(std::forward<T2>(b)); |
| 58 | + return to_fvar(mdivide_left_tri<Eigen::Lower>(A_ref, b_ref.val()), |
| 59 | + mdivide_left_tri<Eigen::Lower>(A_ref, b_ref.d())); |
80 | 60 | } |
81 | 61 |
|
82 | 62 | template <typename T1, typename T2, require_eigen_vt<is_fvar, T1>* = nullptr, |
83 | | - require_eigen_t<T2>* = nullptr, |
84 | | - require_vt_same<double, T2>* = nullptr> |
85 | | -inline Eigen::Matrix<value_type_t<T1>, T1::RowsAtCompileTime, |
86 | | - T2::ColsAtCompileTime> |
87 | | -mdivide_left_tri_low(const T1& A, const T2& b) { |
88 | | - using T = typename value_type_t<T1>::Scalar; |
89 | | - constexpr int S1 = T1::RowsAtCompileTime; |
90 | | - constexpr int C2 = T2::ColsAtCompileTime; |
91 | | - |
| 63 | + require_eigen_vt<std::is_floating_point, T2>* = nullptr> |
| 64 | +inline Eigen::Matrix<value_type_t<T1>, std::decay_t<T1>::RowsAtCompileTime, |
| 65 | + std::decay_t<T2>::ColsAtCompileTime> |
| 66 | +mdivide_left_tri_low(T1&& A, T2&& b) { |
| 67 | + constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime; |
| 68 | + constexpr int C2 = std::decay_t<T2>::ColsAtCompileTime; |
92 | 69 | check_square("mdivide_left_tri_low", "A", A); |
93 | 70 | check_multiplicable("mdivide_left_tri_low", "A", A, "b", b); |
94 | 71 | if (A.size() == 0) { |
95 | 72 | return {0, b.cols()}; |
96 | 73 | } |
97 | | - |
98 | | - Eigen::Matrix<T, S1, S1> val_A(A.rows(), A.cols()); |
99 | | - Eigen::Matrix<T, S1, S1> deriv_A(A.rows(), A.cols()); |
100 | | - val_A.setZero(); |
101 | | - deriv_A.setZero(); |
102 | | - |
103 | | - const Eigen::Ref<const plain_type_t<T1>>& A_ref = A; |
104 | | - for (size_type j = 0; j < A.cols(); j++) { |
105 | | - for (size_type i = j; i < A.rows(); i++) { |
106 | | - val_A(i, j) = A_ref(i, j).val_; |
107 | | - deriv_A(i, j) = A_ref(i, j).d_; |
108 | | - } |
109 | | - } |
110 | | - |
111 | | - Eigen::Matrix<T, S1, C2> inv_A_mult_b = mdivide_left(val_A, b); |
112 | | - |
113 | | - return to_fvar(inv_A_mult_b, |
114 | | - -multiply(mdivide_left(val_A, deriv_A), inv_A_mult_b)); |
| 74 | + decltype(auto) A_ref = to_ref(std::forward<T1>(A)); |
| 75 | + auto inv_A_mult_b |
| 76 | + = eval(mdivide_left_tri<Eigen::Lower>(A_ref.val(), std::forward<T2>(b))); |
| 77 | + return to_fvar( |
| 78 | + inv_A_mult_b, |
| 79 | + -multiply( |
| 80 | + mdivide_left_tri<Eigen::Lower>( |
| 81 | + A_ref.val(), A_ref.d().template triangularView<Eigen::Lower>()), |
| 82 | + inv_A_mult_b)); |
115 | 83 | } |
116 | 84 |
|
117 | 85 | } // namespace math |
|
0 commit comments