11#ifndef STAN_MATH_FWD_FUN_MDIVIDE_RIGHT_HPP
22#define STAN_MATH_FWD_FUN_MDIVIDE_RIGHT_HPP
33
4+ #include < stan/math/fwd/core.hpp>
5+ #include < stan/math/fwd/fun/multiply.hpp>
6+ #include < stan/math/fwd/fun/to_fvar.hpp>
47#include < stan/math/prim/err.hpp>
58#include < stan/math/prim/fun/Eigen.hpp>
69#include < stan/math/prim/fun/mdivide_right.hpp>
710#include < stan/math/prim/fun/multiply.hpp>
8- #include < stan/math/fwd/core.hpp>
9- #include < stan/math/fwd/fun/multiply.hpp>
10- #include < stan/math/fwd/fun/to_fvar.hpp>
11+ #include < stan/math/prim/fun/subtract.hpp>
1112#include < vector>
1213
1314namespace stan {
1415namespace math {
15-
16+ /*
1617template <typename EigMat1, typename EigMat2,
17- require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr ,
18- require_vt_same<EigMat1, EigMat2>* = nullptr >
19- inline Eigen::Matrix<value_type_t <EigMat1>, EigMat1::RowsAtCompileTime,
20- EigMat2::ColsAtCompileTime>
21- mdivide_right (const EigMat1& A, const EigMat2& b) {
22- using T = typename value_type_t <EigMat1>::Scalar;
18+ require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr>
19+ inline auto mdivide_right(const EigMat1& A, const EigMat2& b) {
20+ using T1 = typename return_type_t<EigMat1>::Scalar;
21+ using T2 = typename return_type_t<EigMat2>::Scalar;
22+ using ret_scalar = return_type_t<EigMat1, EigMat2>;
2323 constexpr int R1 = EigMat1::RowsAtCompileTime;
2424 constexpr int C1 = EigMat1::ColsAtCompileTime;
2525 constexpr int R2 = EigMat2::RowsAtCompileTime;
@@ -28,35 +28,24 @@ mdivide_right(const EigMat1& A, const EigMat2& b) {
2828 check_square("mdivide_right", "b", b);
2929 check_multiplicable("mdivide_right", "A", A, "b", b);
3030 if (b.size() == 0) {
31- return { A.rows (), 0 } ;
31+ return Eigen::Matrix<ret_scalar, R1, C2>( A.rows(), 0) ;
3232 }
3333
34- Eigen::Matrix<T, R1 , C1 > val_A (A.rows (), A.cols ());
35- Eigen::Matrix<T, R1 , C1 > deriv_A (A.rows (), A.cols ());
36- Eigen::Matrix<T, R2 , C2 > val_b (b.rows (), b.cols ());
37- Eigen::Matrix<T, R2 , C2 > deriv_b (b.rows (), b.cols ());
38-
39- const Eigen::Ref<const plain_type_t <EigMat1>>& A_ref = A;
40- for (int j = 0 ; j < A.cols (); j++) {
41- for (int i = 0 ; i < A.rows (); i++) {
42- val_A.coeffRef (i, j) = A_ref.coeff (i, j).val_ ;
43- deriv_A.coeffRef (i, j) = A_ref.coeff (i, j).d_ ;
44- }
45- }
4634
47- const Eigen::Ref<const plain_type_t <EigMat2>>& b_ref = b;
48- for (int j = 0 ; j < b.cols (); j++) {
49- for (int i = 0 ; i < b.rows (); i++) {
50- val_b.coeffRef (i, j) = b_ref.coeff (i, j).val_ ;
51- deriv_b.coeffRef (i, j) = b_ref.coeff (i, j).d_ ;
52- }
53- }
35+ auto&& A_ref = to_ref(A);
36+ Eigen::Matrix<T1, R1, C1> val_A = A_ref.val();
37+ Eigen::Matrix<T1, R1, C1> deriv_A = A_ref.d();
5438
55- Eigen::Matrix<T, R1 , C2 > A_mult_inv_b = mdivide_right (val_A, val_b);
39+ auto&& b_ref = to_ref(b);
40+ Eigen::Matrix<T2, R2, C2> val_b = b_ref.val();
41+ Eigen::Matrix<T2, R2, C2> deriv_b = b_ref.d();
5642
57- return to_fvar (A_mult_inv_b,
58- mdivide_right (deriv_A, val_b)
59- - A_mult_inv_b * mdivide_right (deriv_b, val_b));
43+ Eigen::Matrix<return_type_t<T1, T2>, R1, C2> A_mult_inv_b =
44+ mdivide_right(val_A, val_b).template cast<return_type_t<T1, T2>>().eval();
45+ Eigen::Matrix<ret_scalar, R1, C2> res = A_mult_inv_b.template
46+ cast<ret_scalar>().eval(); res.d() = subtract(mdivide_right(deriv_A, val_b),
47+ multiply(A_mult_inv_b, mdivide_right(deriv_b, val_b)));
48+ return res;
6049}
6150
6251template <typename EigMat1, typename EigMat2,
@@ -68,25 +57,21 @@ mdivide_right(const EigMat1& A, const EigMat2& b) {
6857 using T = typename value_type_t<EigMat1>::Scalar;
6958 constexpr int R1 = EigMat1::RowsAtCompileTime;
7059 constexpr int C1 = EigMat1::ColsAtCompileTime;
60+ constexpr int C2 = EigMat2::ColsAtCompileTime;
7161
7262 check_square("mdivide_right", "b", b);
7363 check_multiplicable("mdivide_right", "A", A, "b", b);
7464 if (b.size() == 0) {
7565 return {A.rows(), 0};
7666 }
7767
78- Eigen::Matrix<T, R1 , C1 > val_A (A.rows (), A.cols ());
79- Eigen::Matrix<T, R1 , C1 > deriv_A (A.rows (), A.cols ());
8068
81- const Eigen::Ref<const plain_type_t <EigMat1>>& A_ref = A;
82- for (int j = 0 ; j < A.cols (); j++) {
83- for (int i = 0 ; i < A.rows (); i++) {
84- val_A.coeffRef (i, j) = A_ref.coeff (i, j).val_ ;
85- deriv_A.coeffRef (i, j) = A_ref.coeff (i, j).d_ ;
86- }
87- }
88-
89- return to_fvar (mdivide_right (val_A, b), mdivide_right (deriv_A, b));
69+ auto&& A_ref = to_ref(A);
70+ Eigen::Matrix<T, R1, C1> val_A = A_ref.val();
71+ Eigen::Matrix<T, R1, C1> deriv_A = A_ref.d();
72+ Eigen::Matrix<value_type_t<EigMat1>, R1, C2> res = mdivide_right(val_A,
73+ b).template cast<value_type_t<EigMat1>>(); res.d() = mdivide_right(deriv_A, b);
74+ return res;
9075}
9176
9277template <typename EigMat1, typename EigMat2,
@@ -107,22 +92,16 @@ mdivide_right(const EigMat1& A, const EigMat2& b) {
10792 return {A.rows(), 0};
10893 }
10994
110- Eigen::Matrix<T, R2 , C2 > val_b (b.rows (), b.cols ());
111- Eigen::Matrix<T, R2 , C2 > deriv_b (b.rows (), b.cols ());
112-
113- const Eigen::Ref<const plain_type_t <EigMat2>>& b_ref = b;
114- for (int j = 0 ; j < b.cols (); j++) {
115- for (int i = 0 ; i < b.rows (); i++) {
116- val_b.coeffRef (i, j) = b_ref.coeff (i, j).val_ ;
117- deriv_b.coeffRef (i, j) = b_ref.coeff (i, j).d_ ;
118- }
119- }
95+ auto&& b_ref = to_ref(b);
96+ Eigen::Matrix<T, R2, C2> val_b = b_ref.val();
97+ Eigen::Matrix<T, R2, C2> deriv_b = b_ref.d();
12098
12199 Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(A, val_b);
122-
123- return to_fvar (A_mult_inv_b, -A_mult_inv_b * mdivide_right (deriv_b, val_b));
100+ Eigen::Matrix<value_type_t<EigMat2>, R1, C2> res = A_mult_inv_b.template
101+ cast<value_type_t<EigMat2>>(); res.d() = -multiply(A_mult_inv_b,
102+ mdivide_right(deriv_b, val_b)); return res;
124103}
125-
104+ */
126105} // namespace math
127106} // namespace stan
128107#endif
0 commit comments