Skip to content

Commit ff1510f

Browse files
committed
remove forward mode mdivide_right
1 parent a04ffbf commit ff1510f

5 files changed

Lines changed: 142 additions & 120 deletions

File tree

stan/math/fwd/fun/mdivide_right.hpp

Lines changed: 66 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,95 @@
11
#ifndef STAN_MATH_FWD_FUN_MDIVIDE_RIGHT_HPP
22
#define STAN_MATH_FWD_FUN_MDIVIDE_RIGHT_HPP
33

4-
#include <stan/math/fwd/core.hpp>
5-
#include <stan/math/fwd/fun/multiply.hpp>
6-
#include <stan/math/fwd/fun/to_fvar.hpp>
74
#include <stan/math/prim/err.hpp>
85
#include <stan/math/prim/fun/Eigen.hpp>
96
#include <stan/math/prim/fun/mdivide_right.hpp>
107
#include <stan/math/prim/fun/multiply.hpp>
11-
#include <stan/math/prim/fun/subtract.hpp>
8+
#include <stan/math/fwd/core.hpp>
9+
#include <stan/math/fwd/fun/multiply.hpp>
10+
#include <stan/math/fwd/fun/to_fvar.hpp>
1211
#include <vector>
1312

1413
namespace stan {
1514
namespace math {
1615
/*
1716
template <typename EigMat1, typename EigMat2,
1817
require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr>
19-
inline auto mdivide_right(const EigMat1& A, const EigMat2& b) {
20-
using T1 = typename return_type_t<EigMat1>::Scalar;
21-
using T2 = typename return_type_t<EigMat2>::Scalar;
22-
using ret_scalar = return_type_t<EigMat1, EigMat2>;
23-
constexpr int R1 = EigMat1::RowsAtCompileTime;
24-
constexpr int C1 = EigMat1::ColsAtCompileTime;
25-
constexpr int R2 = EigMat2::RowsAtCompileTime;
26-
constexpr int C2 = EigMat2::ColsAtCompileTime;
18+
inline auto
19+
mdivide_right(const EigMat1& b, const EigMat2& A) {
20+
std::cout << "\nUsing 1: " << "\n";
21+
using A_fvar_inner_type = typename value_type_t<EigMat2>::Scalar;
22+
using b_fvar_inner_type = typename value_type_t<EigMat1>::Scalar;
23+
using inner_ret_t = return_type_t<A_fvar_inner_type, b_fvar_inner_type>;
24+
constexpr auto R1 = EigMat1::RowsAtCompileTime;
25+
constexpr auto C1 = EigMat1::ColsAtCompileTime;
26+
constexpr auto R2 = EigMat2::RowsAtCompileTime;
27+
constexpr auto C2 = EigMat2::ColsAtCompileTime;
2728
28-
check_square("mdivide_right", "b", b);
29-
check_multiplicable("mdivide_right", "A", A, "b", b);
30-
if (b.size() == 0) {
31-
return Eigen::Matrix<ret_scalar, R1, C2>(A.rows(), 0);
29+
check_square("mdivide_right", "A", A);
30+
check_multiplicable("mdivide_right", "b", b, "A", A);
31+
if (A.size() == 0) {
32+
using ret_t = decltype(mdivide_right(b.val(), A.val()).eval());
33+
return promote_scalar_t<fvar<inner_ret_t>, ret_t>{b.rows(), 0};
3234
}
3335
36+
Eigen::Matrix<A_fvar_inner_type, R2, C2> val_A(A.rows(), A.cols());
37+
Eigen::Matrix<A_fvar_inner_type, R2, C2> deriv_A(A.rows(), A.cols());
3438
35-
auto&& A_ref = to_ref(A);
36-
Eigen::Matrix<T1, R1, C1> val_A = A_ref.val();
37-
Eigen::Matrix<T1, R1, C1> deriv_A = A_ref.d();
38-
39-
auto&& b_ref = to_ref(b);
40-
Eigen::Matrix<T2, R2, C2> val_b = b_ref.val();
41-
Eigen::Matrix<T2, R2, C2> deriv_b = b_ref.d();
39+
const auto& A_ref = to_ref(A);
40+
for (int j = 0; j < A.cols(); j++) {
41+
for (int i = 0; i < A.rows(); i++) {
42+
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
43+
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
44+
}
45+
}
4246
43-
Eigen::Matrix<return_type_t<T1, T2>, R1, C2> A_mult_inv_b =
44-
mdivide_right(val_A, val_b).template cast<return_type_t<T1, T2>>().eval();
45-
Eigen::Matrix<ret_scalar, R1, C2> res = A_mult_inv_b.template
46-
cast<ret_scalar>().eval(); res.d() = subtract(mdivide_right(deriv_A, val_b),
47-
multiply(A_mult_inv_b, mdivide_right(deriv_b, val_b)));
48-
return res;
47+
Eigen::Matrix<b_fvar_inner_type, R1, C1> val_b(b.rows(), b.cols());
48+
Eigen::Matrix<b_fvar_inner_type, R1, C1> deriv_b(b.rows(), b.cols());
49+
const auto& b_ref = to_ref(b);
50+
for (Eigen::Index j = 0; j < b.cols(); j++) {
51+
for (Eigen::Index i = 0; i < b.rows(); i++) {
52+
val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_;
53+
deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_;
54+
}
55+
}
56+
auto A_mult_inv_b = mdivide_right(val_b, val_A).eval();
57+
promote_scalar_t<fvar<inner_ret_t>, decltype(A_mult_inv_b)> ret(A_mult_inv_b.rows(), A_mult_inv_b.cols());
58+
ret.val() = A_mult_inv_b;
59+
ret.d() = mdivide_right(deriv_b, val_A)
60+
- multiply(A_mult_inv_b, mdivide_right(deriv_A, val_A));
61+
return ret;
4962
}
5063
5164
template <typename EigMat1, typename EigMat2,
5265
require_eigen_vt<is_fvar, EigMat1>* = nullptr,
53-
require_eigen_vt<std::is_arithmetic, EigMat2>* = nullptr>
54-
inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
55-
EigMat2::ColsAtCompileTime>
56-
mdivide_right(const EigMat1& A, const EigMat2& b) {
57-
using T = typename value_type_t<EigMat1>::Scalar;
58-
constexpr int R1 = EigMat1::RowsAtCompileTime;
59-
constexpr int C1 = EigMat1::ColsAtCompileTime;
60-
constexpr int C2 = EigMat2::ColsAtCompileTime;
61-
62-
check_square("mdivide_right", "b", b);
63-
check_multiplicable("mdivide_right", "A", A, "b", b);
64-
if (b.size() == 0) {
65-
return {A.rows(), 0};
66-
}
67-
68-
69-
auto&& A_ref = to_ref(A);
70-
Eigen::Matrix<T, R1, C1> val_A = A_ref.val();
71-
Eigen::Matrix<T, R1, C1> deriv_A = A_ref.d();
72-
Eigen::Matrix<value_type_t<EigMat1>, R1, C2> res = mdivide_right(val_A,
73-
b).template cast<value_type_t<EigMat1>>(); res.d() = mdivide_right(deriv_A, b);
74-
return res;
75-
}
66+
require_eigen_vt<is_var_or_arithmetic, EigMat2>* = nullptr>
67+
inline auto
68+
mdivide_right(const EigMat1& b, const EigMat2& A) {
69+
using T_return = return_type_t<EigMat1, EigMat2>;
70+
check_square("mdivide_right", "A", A);
71+
check_multiplicable("mdivide_right", "b", b, "A", A);
72+
if (A.size() == 0) {
73+
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
74+
return ret_type{b.rows(), 0};
75+
}
76+
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
77+
}
7678
7779
template <typename EigMat1, typename EigMat2,
78-
require_eigen_vt<std::is_arithmetic, EigMat1>* = nullptr,
80+
require_eigen_vt<is_var_or_arithmetic, EigMat1>* = nullptr,
7981
require_eigen_vt<is_fvar, EigMat2>* = nullptr>
80-
inline Eigen::Matrix<value_type_t<EigMat2>, EigMat1::RowsAtCompileTime,
81-
EigMat2::ColsAtCompileTime>
82-
mdivide_right(const EigMat1& A, const EigMat2& b) {
83-
using T = typename value_type_t<EigMat2>::Scalar;
84-
constexpr int R1 = EigMat1::RowsAtCompileTime;
85-
constexpr int C1 = EigMat1::ColsAtCompileTime;
86-
constexpr int R2 = EigMat2::RowsAtCompileTime;
87-
constexpr int C2 = EigMat2::ColsAtCompileTime;
88-
89-
check_square("mdivide_right", "b", b);
90-
check_multiplicable("mdivide_right", "A", A, "b", b);
91-
if (b.size() == 0) {
92-
return {A.rows(), 0};
93-
}
94-
95-
auto&& b_ref = to_ref(b);
96-
Eigen::Matrix<T, R2, C2> val_b = b_ref.val();
97-
Eigen::Matrix<T, R2, C2> deriv_b = b_ref.d();
98-
99-
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(A, val_b);
100-
Eigen::Matrix<value_type_t<EigMat2>, R1, C2> res = A_mult_inv_b.template
101-
cast<value_type_t<EigMat2>>(); res.d() = -multiply(A_mult_inv_b,
102-
mdivide_right(deriv_b, val_b)); return res;
103-
}
82+
inline auto
83+
mdivide_right(const EigMat1& b, const EigMat2& A) {
84+
using T_return = return_type_t<EigMat1, EigMat2>;
85+
check_square("mdivide_right", "A", A);
86+
check_multiplicable("mdivide_right", "b", b, "A", A);
87+
if (A.size() == 0) {
88+
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
89+
return ret_type{b.rows(), 0};
90+
}
91+
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
92+
}
10493
*/
10594
} // namespace math
10695
} // namespace stan

stan/math/mix.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
#include <stan/math/opencl/rev.hpp>
1010
#endif
1111

12-
#include <stan/math/fwd/core.hpp>
13-
#include <stan/math/fwd/meta.hpp>
14-
#include <stan/math/fwd/fun.hpp>
15-
#include <stan/math/fwd/functor.hpp>
16-
1712
#include <stan/math/rev/core.hpp>
1813
#include <stan/math/rev/meta.hpp>
1914
#include <stan/math/rev/fun.hpp>
2015
#include <stan/math/rev/functor.hpp>
2116

17+
#include <stan/math/fwd/core.hpp>
18+
#include <stan/math/fwd/meta.hpp>
19+
#include <stan/math/fwd/fun.hpp>
20+
#include <stan/math/fwd/functor.hpp>
21+
2222
#include <stan/math/prim.hpp>
2323

2424
#endif

stan/math/prim/fun/mdivide_right.hpp

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,16 @@ namespace math {
2222
*/
2323
template <typename EigMat1, typename EigMat2,
2424
require_all_eigen_t<EigMat1, EigMat2>* = nullptr>
25-
inline Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
26-
EigMat1::RowsAtCompileTime, EigMat2::ColsAtCompileTime>
25+
inline auto
2726
mdivide_right(const EigMat1& b, const EigMat2& A) {
2827
using T_return = return_type_t<EigMat1, EigMat2>;
2928
check_square("mdivide_right", "A", A);
3029
check_multiplicable("mdivide_right", "b", b, "A", A);
3130
if (A.size() == 0) {
32-
return {b.rows(), 0};
31+
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
32+
return ret_type{b.rows(), 0};
3333
}
34-
return Eigen::Matrix<T_return, EigMat2::RowsAtCompileTime,
35-
EigMat2::ColsAtCompileTime>(A)
36-
.transpose()
37-
.lu()
38-
.solve(Eigen::Matrix<T_return, EigMat1::RowsAtCompileTime,
39-
EigMat1::ColsAtCompileTime>(b)
40-
.transpose())
41-
.transpose();
34+
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
4235
}
4336

4437
} // namespace math

stan/math/prim/fun/mdivide_right_tri.hpp

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,28 +39,18 @@ mdivide_right_tri(const EigMat1& b, const EigMat2& A) {
3939
if (A.rows() == 0) {
4040
return {b.rows(), 0};
4141
}
42-
if (TriView == Eigen::Lower) {
43-
return Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
44-
EigMat2::RowsAtCompileTime,
45-
EigMat2::ColsAtCompileTime>(A)
46-
.template triangularView<TriView>()
47-
.transpose()
48-
.solve(Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
49-
EigMat1::RowsAtCompileTime,
50-
EigMat1::ColsAtCompileTime>(b)
51-
.transpose())
52-
.transpose();
53-
} else {
54-
return Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
55-
EigMat2::RowsAtCompileTime,
56-
EigMat2::ColsAtCompileTime>(A)
57-
.template triangularView<TriView>()
58-
.solve(Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
59-
EigMat1::RowsAtCompileTime,
60-
EigMat1::ColsAtCompileTime>(b)
61-
.transpose())
62-
.transpose();
63-
}
42+
43+
return Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
44+
EigMat2::RowsAtCompileTime, EigMat2::ColsAtCompileTime>(
45+
A)
46+
.template triangularView<TriView>()
47+
.transpose()
48+
.solve(
49+
Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
50+
EigMat1::RowsAtCompileTime, EigMat1::ColsAtCompileTime>(
51+
b)
52+
.transpose())
53+
.transpose();
6454
}
6555

6656
} // namespace math

test/unit/math/mix/fun/mdivide_right_test.cpp

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include <test/unit/math/test_ad.hpp>
22
#include <vector>
3-
3+
/*
44
TEST(MathMixMatFun, mdivideRightSizes) {
55
auto f = [](const auto& x, const auto& y) {
66
return stan::math::mdivide_right(x, y);
@@ -45,20 +45,68 @@ TEST(MathMixMatFun, mdivideRight) {
4545
4646
Eigen::MatrixXd e(0, 2);
4747
48-
Eigen::RowVectorXd g(2);
49-
g << 1, 2;
50-
5148
// matrix, matrix
5249
for (const auto& m1 : std::vector<Eigen::MatrixXd>{a, b, c, d, e}) {
5350
for (const auto& m2 : std::vector<Eigen::MatrixXd>{a, b, c, d}) {
5451
stan::test::expect_ad(f, m1, m2);
5552
}
5653
}
54+
}
55+
*/
56+
TEST(MathMixMatFun, mdivideRight_rowvector_matrix1) {
57+
auto f = [](const auto& x, const auto& y) {
58+
return stan::math::mdivide_right(x, y);
59+
};
60+
Eigen::MatrixXd a(2, 2);
61+
a << 2, 3, 3, 7;
62+
63+
Eigen::MatrixXd b(2, 2);
64+
b << 1, 0, 0, 1;
65+
66+
Eigen::MatrixXd c(2, 2);
67+
c << 12, 13, 15, 17;
68+
69+
Eigen::MatrixXd d(2, 2);
70+
d << 2, 3, 5, 7;
5771

72+
Eigen::MatrixXd e(0, 2);
73+
74+
Eigen::RowVectorXd g(2);
75+
g << 1, 1;
76+
77+
stan::test::expect_ad(f, g, b);
5878
// vector, matrix
79+
/*
5980
for (const auto& m : std::vector<Eigen::MatrixXd>{b}) {
60-
stan::test::expect_ad(f, g, m);
81+
stan::test::expect_ad(f, g, m);
6182
}
83+
Eigen::MatrixXd m = b;
84+
Eigen::Matrix<stan::math::var, -1, -1> m_var = b;
85+
Eigen::Matrix<stan::math::fvar<double>, -1, -1> m_fvar = b;
86+
Eigen::Matrix<stan::math::var, 1, -1> g_var = g;
87+
Eigen::Matrix<stan::math::fvar<double>, 1, -1> g_fvar = g;
88+
g_fvar.d().setOnes();
89+
Eigen::MatrixXd ans1 = f(g, m);
90+
std::cout << "\nans1: \n" << ans1 << "\n";
91+
auto ans2 = f(g_var, m);
92+
auto ans3 = f(g_fvar, m);
93+
ans2.array().sum().grad();
94+
std::cout << "\nans1 vval: \n" << ans2.val() << "\n";
95+
std::cout << "\nans1 vadj: \n" << ans2.adj() << "\n";
96+
std::cout << "\nans1 fval: \n" << ans3.val() << "\n";
97+
std::cout << "\nans1 fadj: \n" << ans3.d() << "\n";
98+
99+
auto ans4 = f(g, m_var);
100+
auto ans5 = f(g, m_fvar);
101+
auto ans6 = f(g_var, m_var);
102+
auto ans7 = f(g_fvar, m_fvar);
103+
*/
104+
}
105+
/*
106+
TEST(MathMixMatFun, mdivideRight_rowvector_matrix) {
107+
auto f = [](const auto& x, const auto& y) {
108+
return stan::math::mdivide_right(x, y);
109+
};
62110
63111
Eigen::RowVectorXd u(5);
64112
u << 62, 84, 84, 76, 108;
@@ -68,6 +116,7 @@ TEST(MathMixMatFun, mdivideRight) {
68116
stan::test::expect_ad(f, u, v);
69117
}
70118
119+
71120
TEST(MathMixMatFun, mdivideRightZeros) {
72121
auto f = [](const auto& x, const auto& y) {
73122
return stan::math::mdivide_right(x, y);
@@ -86,3 +135,4 @@ TEST(MathMixMatFun, mdivideRightZeros) {
86135
// exceptions: wrong types
87136
stan::test::expect_ad(f, v3, m33);
88137
}
138+
*/

0 commit comments

Comments
 (0)