|
1 | 1 | #ifndef STAN_MATH_FWD_FUN_MDIVIDE_RIGHT_HPP |
2 | 2 | #define STAN_MATH_FWD_FUN_MDIVIDE_RIGHT_HPP |
3 | 3 |
|
4 | | -#include <stan/math/fwd/core.hpp> |
5 | | -#include <stan/math/fwd/fun/multiply.hpp> |
6 | | -#include <stan/math/fwd/fun/to_fvar.hpp> |
7 | 4 | #include <stan/math/prim/err.hpp> |
8 | 5 | #include <stan/math/prim/fun/Eigen.hpp> |
9 | 6 | #include <stan/math/prim/fun/mdivide_right.hpp> |
10 | 7 | #include <stan/math/prim/fun/multiply.hpp> |
11 | | -#include <stan/math/prim/fun/subtract.hpp> |
| 8 | +#include <stan/math/fwd/core.hpp> |
| 9 | +#include <stan/math/fwd/fun/multiply.hpp> |
| 10 | +#include <stan/math/fwd/fun/to_fvar.hpp> |
12 | 11 | #include <vector> |
13 | 12 |
|
14 | 13 | namespace stan { |
15 | 14 | namespace math { |
16 | 15 | /* |
17 | 16 | template <typename EigMat1, typename EigMat2, |
18 | 17 | require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr> |
19 | | -inline auto mdivide_right(const EigMat1& A, const EigMat2& b) { |
20 | | - using T1 = typename return_type_t<EigMat1>::Scalar; |
21 | | - using T2 = typename return_type_t<EigMat2>::Scalar; |
22 | | - using ret_scalar = return_type_t<EigMat1, EigMat2>; |
23 | | - constexpr int R1 = EigMat1::RowsAtCompileTime; |
24 | | - constexpr int C1 = EigMat1::ColsAtCompileTime; |
25 | | - constexpr int R2 = EigMat2::RowsAtCompileTime; |
26 | | - constexpr int C2 = EigMat2::ColsAtCompileTime; |
| 18 | +inline auto |
| 19 | +mdivide_right(const EigMat1& b, const EigMat2& A) { |
| 20 | + std::cout << "\nUsing 1: " << "\n"; |
| 21 | + using A_fvar_inner_type = typename value_type_t<EigMat2>::Scalar; |
| 22 | + using b_fvar_inner_type = typename value_type_t<EigMat1>::Scalar; |
| 23 | + using inner_ret_t = return_type_t<A_fvar_inner_type, b_fvar_inner_type>; |
| 24 | + constexpr auto R1 = EigMat1::RowsAtCompileTime; |
| 25 | + constexpr auto C1 = EigMat1::ColsAtCompileTime; |
| 26 | + constexpr auto R2 = EigMat2::RowsAtCompileTime; |
| 27 | + constexpr auto C2 = EigMat2::ColsAtCompileTime; |
27 | 28 |
|
28 | | - check_square("mdivide_right", "b", b); |
29 | | - check_multiplicable("mdivide_right", "A", A, "b", b); |
30 | | - if (b.size() == 0) { |
31 | | - return Eigen::Matrix<ret_scalar, R1, C2>(A.rows(), 0); |
| 29 | + check_square("mdivide_right", "A", A); |
| 30 | + check_multiplicable("mdivide_right", "b", b, "A", A); |
| 31 | + if (A.size() == 0) { |
| 32 | + using ret_t = decltype(mdivide_right(b.val(), A.val()).eval()); |
| 33 | + return promote_scalar_t<fvar<inner_ret_t>, ret_t>{b.rows(), 0}; |
32 | 34 | } |
33 | 35 |
|
| 36 | + Eigen::Matrix<A_fvar_inner_type, R2, C2> val_A(A.rows(), A.cols()); |
| 37 | + Eigen::Matrix<A_fvar_inner_type, R2, C2> deriv_A(A.rows(), A.cols()); |
34 | 38 |
|
35 | | - auto&& A_ref = to_ref(A); |
36 | | - Eigen::Matrix<T1, R1, C1> val_A = A_ref.val(); |
37 | | - Eigen::Matrix<T1, R1, C1> deriv_A = A_ref.d(); |
38 | | -
|
39 | | - auto&& b_ref = to_ref(b); |
40 | | - Eigen::Matrix<T2, R2, C2> val_b = b_ref.val(); |
41 | | - Eigen::Matrix<T2, R2, C2> deriv_b = b_ref.d(); |
| 39 | + const auto& A_ref = to_ref(A); |
| 40 | + for (int j = 0; j < A.cols(); j++) { |
| 41 | + for (int i = 0; i < A.rows(); i++) { |
| 42 | + val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_; |
| 43 | + deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_; |
| 44 | + } |
| 45 | + } |
42 | 46 |
|
43 | | - Eigen::Matrix<return_type_t<T1, T2>, R1, C2> A_mult_inv_b = |
44 | | -mdivide_right(val_A, val_b).template cast<return_type_t<T1, T2>>().eval(); |
45 | | - Eigen::Matrix<ret_scalar, R1, C2> res = A_mult_inv_b.template |
46 | | -cast<ret_scalar>().eval(); res.d() = subtract(mdivide_right(deriv_A, val_b), |
47 | | - multiply(A_mult_inv_b, mdivide_right(deriv_b, val_b))); |
48 | | - return res; |
| 47 | + Eigen::Matrix<b_fvar_inner_type, R1, C1> val_b(b.rows(), b.cols()); |
| 48 | + Eigen::Matrix<b_fvar_inner_type, R1, C1> deriv_b(b.rows(), b.cols()); |
| 49 | + const auto& b_ref = to_ref(b); |
| 50 | + for (Eigen::Index j = 0; j < b.cols(); j++) { |
| 51 | + for (Eigen::Index i = 0; i < b.rows(); i++) { |
| 52 | + val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_; |
| 53 | + deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_; |
| 54 | + } |
| 55 | + } |
| 56 | + auto A_mult_inv_b = mdivide_right(val_b, val_A).eval(); |
| 57 | + promote_scalar_t<fvar<inner_ret_t>, decltype(A_mult_inv_b)> ret(A_mult_inv_b.rows(), A_mult_inv_b.cols()); |
| 58 | + ret.val() = A_mult_inv_b; |
| 59 | + ret.d() = mdivide_right(deriv_b, val_A) |
| 60 | + - multiply(A_mult_inv_b, mdivide_right(deriv_A, val_A)); |
| 61 | + return ret; |
49 | 62 | } |
50 | 63 |
|
51 | 64 | template <typename EigMat1, typename EigMat2, |
52 | 65 | require_eigen_vt<is_fvar, EigMat1>* = nullptr, |
53 | | - require_eigen_vt<std::is_arithmetic, EigMat2>* = nullptr> |
54 | | -inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime, |
55 | | - EigMat2::ColsAtCompileTime> |
56 | | -mdivide_right(const EigMat1& A, const EigMat2& b) { |
57 | | - using T = typename value_type_t<EigMat1>::Scalar; |
58 | | - constexpr int R1 = EigMat1::RowsAtCompileTime; |
59 | | - constexpr int C1 = EigMat1::ColsAtCompileTime; |
60 | | - constexpr int C2 = EigMat2::ColsAtCompileTime; |
61 | | -
|
62 | | - check_square("mdivide_right", "b", b); |
63 | | - check_multiplicable("mdivide_right", "A", A, "b", b); |
64 | | - if (b.size() == 0) { |
65 | | - return {A.rows(), 0}; |
66 | | - } |
67 | | -
|
68 | | -
|
69 | | - auto&& A_ref = to_ref(A); |
70 | | - Eigen::Matrix<T, R1, C1> val_A = A_ref.val(); |
71 | | - Eigen::Matrix<T, R1, C1> deriv_A = A_ref.d(); |
72 | | - Eigen::Matrix<value_type_t<EigMat1>, R1, C2> res = mdivide_right(val_A, |
73 | | -b).template cast<value_type_t<EigMat1>>(); res.d() = mdivide_right(deriv_A, b); |
74 | | - return res; |
75 | | -} |
| 66 | + require_eigen_vt<is_var_or_arithmetic, EigMat2>* = nullptr> |
| 67 | + inline auto |
| 68 | + mdivide_right(const EigMat1& b, const EigMat2& A) { |
| 69 | + using T_return = return_type_t<EigMat1, EigMat2>; |
| 70 | + check_square("mdivide_right", "A", A); |
| 71 | + check_multiplicable("mdivide_right", "b", b, "A", A); |
| 72 | + if (A.size() == 0) { |
| 73 | + using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval()); |
| 74 | + return ret_type{b.rows(), 0}; |
| 75 | + } |
| 76 | + return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval(); |
| 77 | + } |
76 | 78 |
|
77 | 79 | template <typename EigMat1, typename EigMat2, |
78 | | - require_eigen_vt<std::is_arithmetic, EigMat1>* = nullptr, |
| 80 | + require_eigen_vt<is_var_or_arithmetic, EigMat1>* = nullptr, |
79 | 81 | require_eigen_vt<is_fvar, EigMat2>* = nullptr> |
80 | | -inline Eigen::Matrix<value_type_t<EigMat2>, EigMat1::RowsAtCompileTime, |
81 | | - EigMat2::ColsAtCompileTime> |
82 | | -mdivide_right(const EigMat1& A, const EigMat2& b) { |
83 | | - using T = typename value_type_t<EigMat2>::Scalar; |
84 | | - constexpr int R1 = EigMat1::RowsAtCompileTime; |
85 | | - constexpr int C1 = EigMat1::ColsAtCompileTime; |
86 | | - constexpr int R2 = EigMat2::RowsAtCompileTime; |
87 | | - constexpr int C2 = EigMat2::ColsAtCompileTime; |
88 | | -
|
89 | | - check_square("mdivide_right", "b", b); |
90 | | - check_multiplicable("mdivide_right", "A", A, "b", b); |
91 | | - if (b.size() == 0) { |
92 | | - return {A.rows(), 0}; |
93 | | - } |
94 | | -
|
95 | | - auto&& b_ref = to_ref(b); |
96 | | - Eigen::Matrix<T, R2, C2> val_b = b_ref.val(); |
97 | | - Eigen::Matrix<T, R2, C2> deriv_b = b_ref.d(); |
98 | | -
|
99 | | - Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(A, val_b); |
100 | | - Eigen::Matrix<value_type_t<EigMat2>, R1, C2> res = A_mult_inv_b.template |
101 | | -cast<value_type_t<EigMat2>>(); res.d() = -multiply(A_mult_inv_b, |
102 | | -mdivide_right(deriv_b, val_b)); return res; |
103 | | -} |
| 82 | + inline auto |
| 83 | + mdivide_right(const EigMat1& b, const EigMat2& A) { |
| 84 | + using T_return = return_type_t<EigMat1, EigMat2>; |
| 85 | + check_square("mdivide_right", "A", A); |
| 86 | + check_multiplicable("mdivide_right", "b", b, "A", A); |
| 87 | + if (A.size() == 0) { |
| 88 | + using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval()); |
| 89 | + return ret_type{b.rows(), 0}; |
| 90 | + } |
| 91 | + return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval(); |
| 92 | + } |
104 | 93 | */ |
105 | 94 | } // namespace math |
106 | 95 | } // namespace stan |
|
0 commit comments