Skip to content

Commit a04ffbf

Browse files
committed
stan changes
1 parent 192c775 commit a04ffbf

12 files changed

Lines changed: 248 additions & 109 deletions
Lines changed: 37 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
#ifndef STAN_MATH_FWD_FUN_MDIVIDE_RIGHT_HPP
22
#define STAN_MATH_FWD_FUN_MDIVIDE_RIGHT_HPP
33

4+
#include <stan/math/fwd/core.hpp>
5+
#include <stan/math/fwd/fun/multiply.hpp>
6+
#include <stan/math/fwd/fun/to_fvar.hpp>
47
#include <stan/math/prim/err.hpp>
58
#include <stan/math/prim/fun/Eigen.hpp>
69
#include <stan/math/prim/fun/mdivide_right.hpp>
710
#include <stan/math/prim/fun/multiply.hpp>
8-
#include <stan/math/fwd/core.hpp>
9-
#include <stan/math/fwd/fun/multiply.hpp>
10-
#include <stan/math/fwd/fun/to_fvar.hpp>
11+
#include <stan/math/prim/fun/subtract.hpp>
1112
#include <vector>
1213

1314
namespace stan {
1415
namespace math {
15-
16+
/*
1617
template <typename EigMat1, typename EigMat2,
17-
require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr,
18-
require_vt_same<EigMat1, EigMat2>* = nullptr>
19-
inline Eigen::Matrix<value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
20-
EigMat2::ColsAtCompileTime>
21-
mdivide_right(const EigMat1& A, const EigMat2& b) {
22-
using T = typename value_type_t<EigMat1>::Scalar;
18+
require_all_eigen_vt<is_fvar, EigMat1, EigMat2>* = nullptr>
19+
inline auto mdivide_right(const EigMat1& A, const EigMat2& b) {
20+
using T1 = typename return_type_t<EigMat1>::Scalar;
21+
using T2 = typename return_type_t<EigMat2>::Scalar;
22+
using ret_scalar = return_type_t<EigMat1, EigMat2>;
2323
constexpr int R1 = EigMat1::RowsAtCompileTime;
2424
constexpr int C1 = EigMat1::ColsAtCompileTime;
2525
constexpr int R2 = EigMat2::RowsAtCompileTime;
@@ -28,35 +28,24 @@ mdivide_right(const EigMat1& A, const EigMat2& b) {
2828
check_square("mdivide_right", "b", b);
2929
check_multiplicable("mdivide_right", "A", A, "b", b);
3030
if (b.size() == 0) {
31-
return {A.rows(), 0};
31+
return Eigen::Matrix<ret_scalar, R1, C2>(A.rows(), 0);
3232
}
3333
34-
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
35-
Eigen::Matrix<T, R1, C1> deriv_A(A.rows(), A.cols());
36-
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
37-
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
38-
39-
const Eigen::Ref<const plain_type_t<EigMat1>>& A_ref = A;
40-
for (int j = 0; j < A.cols(); j++) {
41-
for (int i = 0; i < A.rows(); i++) {
42-
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
43-
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
44-
}
45-
}
4634
47-
const Eigen::Ref<const plain_type_t<EigMat2>>& b_ref = b;
48-
for (int j = 0; j < b.cols(); j++) {
49-
for (int i = 0; i < b.rows(); i++) {
50-
val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_;
51-
deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_;
52-
}
53-
}
35+
auto&& A_ref = to_ref(A);
36+
Eigen::Matrix<T1, R1, C1> val_A = A_ref.val();
37+
Eigen::Matrix<T1, R1, C1> deriv_A = A_ref.d();
5438
55-
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(val_A, val_b);
39+
auto&& b_ref = to_ref(b);
40+
Eigen::Matrix<T2, R2, C2> val_b = b_ref.val();
41+
Eigen::Matrix<T2, R2, C2> deriv_b = b_ref.d();
5642
57-
return to_fvar(A_mult_inv_b,
58-
mdivide_right(deriv_A, val_b)
59-
- A_mult_inv_b * mdivide_right(deriv_b, val_b));
43+
Eigen::Matrix<return_type_t<T1, T2>, R1, C2> A_mult_inv_b =
44+
mdivide_right(val_A, val_b).template cast<return_type_t<T1, T2>>().eval();
45+
Eigen::Matrix<ret_scalar, R1, C2> res = A_mult_inv_b.template
46+
cast<ret_scalar>().eval(); res.d() = subtract(mdivide_right(deriv_A, val_b),
47+
multiply(A_mult_inv_b, mdivide_right(deriv_b, val_b)));
48+
return res;
6049
}
6150
6251
template <typename EigMat1, typename EigMat2,
@@ -68,25 +57,21 @@ mdivide_right(const EigMat1& A, const EigMat2& b) {
6857
using T = typename value_type_t<EigMat1>::Scalar;
6958
constexpr int R1 = EigMat1::RowsAtCompileTime;
7059
constexpr int C1 = EigMat1::ColsAtCompileTime;
60+
constexpr int C2 = EigMat2::ColsAtCompileTime;
7161
7262
check_square("mdivide_right", "b", b);
7363
check_multiplicable("mdivide_right", "A", A, "b", b);
7464
if (b.size() == 0) {
7565
return {A.rows(), 0};
7666
}
7767
78-
Eigen::Matrix<T, R1, C1> val_A(A.rows(), A.cols());
79-
Eigen::Matrix<T, R1, C1> deriv_A(A.rows(), A.cols());
8068
81-
const Eigen::Ref<const plain_type_t<EigMat1>>& A_ref = A;
82-
for (int j = 0; j < A.cols(); j++) {
83-
for (int i = 0; i < A.rows(); i++) {
84-
val_A.coeffRef(i, j) = A_ref.coeff(i, j).val_;
85-
deriv_A.coeffRef(i, j) = A_ref.coeff(i, j).d_;
86-
}
87-
}
88-
89-
return to_fvar(mdivide_right(val_A, b), mdivide_right(deriv_A, b));
69+
auto&& A_ref = to_ref(A);
70+
Eigen::Matrix<T, R1, C1> val_A = A_ref.val();
71+
Eigen::Matrix<T, R1, C1> deriv_A = A_ref.d();
72+
Eigen::Matrix<value_type_t<EigMat1>, R1, C2> res = mdivide_right(val_A,
73+
b).template cast<value_type_t<EigMat1>>(); res.d() = mdivide_right(deriv_A, b);
74+
return res;
9075
}
9176
9277
template <typename EigMat1, typename EigMat2,
@@ -107,22 +92,16 @@ mdivide_right(const EigMat1& A, const EigMat2& b) {
10792
return {A.rows(), 0};
10893
}
10994
110-
Eigen::Matrix<T, R2, C2> val_b(b.rows(), b.cols());
111-
Eigen::Matrix<T, R2, C2> deriv_b(b.rows(), b.cols());
112-
113-
const Eigen::Ref<const plain_type_t<EigMat2>>& b_ref = b;
114-
for (int j = 0; j < b.cols(); j++) {
115-
for (int i = 0; i < b.rows(); i++) {
116-
val_b.coeffRef(i, j) = b_ref.coeff(i, j).val_;
117-
deriv_b.coeffRef(i, j) = b_ref.coeff(i, j).d_;
118-
}
119-
}
95+
auto&& b_ref = to_ref(b);
96+
Eigen::Matrix<T, R2, C2> val_b = b_ref.val();
97+
Eigen::Matrix<T, R2, C2> deriv_b = b_ref.d();
12098
12199
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(A, val_b);
122-
123-
return to_fvar(A_mult_inv_b, -A_mult_inv_b * mdivide_right(deriv_b, val_b));
100+
Eigen::Matrix<value_type_t<EigMat2>, R1, C2> res = A_mult_inv_b.template
101+
cast<value_type_t<EigMat2>>(); res.d() = -multiply(A_mult_inv_b,
102+
mdivide_right(deriv_b, val_b)); return res;
124103
}
125-
104+
*/
126105
} // namespace math
127106
} // namespace stan
128107
#endif

stan/math/prim/fun/eigenvalues.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77
namespace stan {
88
namespace math {
99

10-
template <typename T>
11-
Eigen::Matrix<std::complex<T>, -1, 1> eigenvalues(
12-
const Eigen::Matrix<T, -1, -1>& m) {
10+
template <typename Mat, require_eigen_t<Mat>* = nullptr>
11+
inline auto eigenvalues(const Mat& m) {
1312
check_nonzero_size("eigenvalues", "m", m);
1413
check_square("eigenvalues", "m", m);
15-
16-
Eigen::EigenSolver<Eigen::Matrix<T, -1, -1>> solver(m);
17-
return solver.eigenvalues();
14+
return m.eigenvalues().eval();
1815
}
1916

2017
} // namespace math

stan/math/prim/fun/mdivide_left_tri.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,24 @@ namespace math {
2626
template <Eigen::UpLoType TriView, typename T1, typename T2,
2727
require_all_eigen_t<T1, T2> * = nullptr,
2828
require_all_not_eigen_vt<is_var, T1, T2> * = nullptr>
29-
inline Eigen::Matrix<return_type_t<T1, T2>, T1::RowsAtCompileTime,
30-
T2::ColsAtCompileTime>
31-
mdivide_left_tri(const T1 &A, const T2 &b) {
29+
inline auto mdivide_left_tri(const T1 &A, const T2 &b) {
3230
using T_return = return_type_t<T1, T2>;
3331
check_square("mdivide_left_tri", "A", A);
3432
check_multiplicable("mdivide_left_tri", "A", A, "b", b);
33+
using ret_type = decltype(A.template cast<T_return>()
34+
.eval()
35+
.template triangularView<TriView>()
36+
.solve(b.template cast<T_return>().eval())
37+
.eval());
3538
if (A.rows() == 0) {
36-
return {0, b.cols()};
39+
return ret_type(0, b.cols());
3740
}
3841

3942
return A.template cast<T_return>()
4043
.eval()
4144
.template triangularView<TriView>()
42-
.solve(b.template cast<T_return>().eval());
45+
.solve(b.template cast<T_return>().eval())
46+
.eval();
4347
}
4448

4549
/**

stan/math/prim/fun/mdivide_right.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ namespace math {
2121
* match the size of A.
2222
*/
2323
template <typename EigMat1, typename EigMat2,
24-
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
25-
require_all_not_vt_fvar<EigMat1, EigMat2>* = nullptr>
24+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr>
2625
inline Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
2726
EigMat1::RowsAtCompileTime, EigMat2::ColsAtCompileTime>
2827
mdivide_right(const EigMat1& b, const EigMat2& A) {
@@ -32,7 +31,6 @@ mdivide_right(const EigMat1& b, const EigMat2& A) {
3231
if (A.size() == 0) {
3332
return {b.rows(), 0};
3433
}
35-
3634
return Eigen::Matrix<T_return, EigMat2::RowsAtCompileTime,
3735
EigMat2::ColsAtCompileTime>(A)
3836
.transpose()

stan/math/prim/fun/mdivide_right_tri.hpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,28 @@ mdivide_right_tri(const EigMat1& b, const EigMat2& A) {
3939
if (A.rows() == 0) {
4040
return {b.rows(), 0};
4141
}
42-
43-
return Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
44-
EigMat2::RowsAtCompileTime, EigMat2::ColsAtCompileTime>(
45-
A)
46-
.template triangularView<TriView>()
47-
.transpose()
48-
.solve(
49-
Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
50-
EigMat1::RowsAtCompileTime, EigMat1::ColsAtCompileTime>(
51-
b)
52-
.transpose())
53-
.transpose();
42+
if (TriView == Eigen::Lower) {
43+
return Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
44+
EigMat2::RowsAtCompileTime,
45+
EigMat2::ColsAtCompileTime>(A)
46+
.template triangularView<TriView>()
47+
.transpose()
48+
.solve(Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
49+
EigMat1::RowsAtCompileTime,
50+
EigMat1::ColsAtCompileTime>(b)
51+
.transpose())
52+
.transpose();
53+
} else {
54+
return Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
55+
EigMat2::RowsAtCompileTime,
56+
EigMat2::ColsAtCompileTime>(A)
57+
.template triangularView<TriView>()
58+
.solve(Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
59+
EigMat1::RowsAtCompileTime,
60+
EigMat1::ColsAtCompileTime>(b)
61+
.transpose())
62+
.transpose();
63+
}
5464
}
5565

5666
} // namespace math

stan/math/prim/meta/plain_type.hpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/meta/is_eigen.hpp>
55
#include <stan/math/prim/meta/is_detected.hpp>
6+
#include <stan/math/prim/meta/is_var_matrix.hpp>
67
#include <type_traits>
78

89
namespace stan {
@@ -40,13 +41,43 @@ struct eval_return_type {
4041
template <typename T>
4142
using eval_return_type_t = typename eval_return_type<T>::type;
4243

44+
namespace internal {
45+
// primary template handles types that have no nested ::type member:
46+
template <class, class = void>
47+
struct has_plain_object : std::false_type {};
48+
49+
// specialization recognizes types that do have a nested ::type member:
50+
template <class T>
51+
struct has_plain_object<T, void_t<typename std::decay_t<T>::PlainObject>>
52+
: std::true_type {};
53+
54+
// primary template handles types that have no nested ::type member:
55+
template <class, class = void>
56+
struct has_eval : std::false_type {};
57+
58+
// specialization recognizes types that do have a nested ::type member:
59+
template <class T>
60+
struct has_eval<T, void_t<decltype(std::declval<std::decay_t<T>&>().eval())>>
61+
: std::true_type {};
62+
63+
} // namespace internal
64+
4365
/**
4466
* Determines plain (non expression) type associated with \c T. For \c Eigen
4567
* expression it is a type the expression can be evaluated into.
4668
* @tparam T type to determine plain type of
4769
*/
4870
template <typename T>
49-
struct plain_type<T, require_eigen_t<T>> {
71+
struct plain_type<T, require_t<bool_constant<internal::has_eval<T>::value
72+
&& is_eigen<T>::value>>> {
73+
using type = std::decay_t<decltype(std::declval<T&>().eval())>;
74+
};
75+
76+
template <typename T>
77+
struct plain_type<
78+
T, require_t<bool_constant<!internal::has_eval<T>::value
79+
&& internal::has_plain_object<T>::value
80+
&& is_eigen<T>::value>>> {
5081
using type = typename std::decay_t<T>::PlainObject;
5182
};
5283

stan/math/rev/core/Eigen_NumTraits.hpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,41 @@ struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
354354
}
355355
};
356356

357+
template <typename Index, typename LhsMapper, bool ConjugateLhs,
358+
bool ConjugateRhs, typename RhsMapper, int Version>
359+
struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
360+
ColMajor, ConjugateLhs, double, RhsMapper,
361+
ConjugateRhs, Version> {
362+
using LhsScalar = stan::math::var;
363+
using RhsScalar = double;
364+
using ResScalar = stan::math::var;
365+
enum { LhsStorageOrder = ColMajor };
366+
367+
EIGEN_DONT_INLINE static void run(Index rows, Index cols,
368+
const LhsMapper& lhsMapper,
369+
const RhsMapper& rhsMapper, ResScalar* res,
370+
Index resIncr, const ResScalar& alpha) {
371+
const LhsScalar* lhs = lhsMapper.data();
372+
const Index lhsStride = lhsMapper.stride();
373+
const RhsScalar* rhs = rhsMapper.data();
374+
const Index rhsIncr = rhsMapper.stride();
375+
run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
376+
}
377+
378+
EIGEN_DONT_INLINE static void run(Index rows, Index cols,
379+
const LhsScalar* lhs, Index lhsStride,
380+
const RhsScalar* rhs, Index rhsIncr,
381+
ResScalar* res, Index resIncr,
382+
const ResScalar& alpha) {
383+
using stan::math::gevv_vvv_vari;
384+
using stan::math::var;
385+
for (Index i = 0; i < rows; ++i) {
386+
res[i * resIncr] += var(
387+
new gevv_vvv_vari(&alpha, &lhs[i], lhsStride, rhs, rhsIncr, cols));
388+
}
389+
}
390+
};
391+
357392
template <typename Index, typename LhsMapper, bool ConjugateLhs,
358393
bool ConjugateRhs, typename RhsMapper, int Version>
359394
struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
@@ -394,6 +429,46 @@ struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
394429
}
395430
};
396431

432+
template <typename Index, typename LhsMapper, bool ConjugateLhs,
433+
bool ConjugateRhs, typename RhsMapper, int Version>
434+
struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
435+
RowMajor, ConjugateLhs, double, RhsMapper,
436+
ConjugateRhs, Version> {
437+
using LhsScalar = stan::math::var;
438+
using RhsScalar = double;
439+
using ResScalar = stan::math::var;
440+
enum { LhsStorageOrder = RowMajor };
441+
442+
EIGEN_DONT_INLINE static void run(Index rows, Index cols,
443+
const LhsMapper& lhsMapper,
444+
const RhsMapper& rhsMapper, ResScalar* res,
445+
Index resIncr, const RhsScalar& alpha) {
446+
const LhsScalar* lhs = lhsMapper.data();
447+
const Index lhsStride = lhsMapper.stride();
448+
const RhsScalar* rhs = rhsMapper.data();
449+
const Index rhsIncr = rhsMapper.stride();
450+
run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
451+
}
452+
453+
EIGEN_DONT_INLINE static void run(Index rows, Index cols,
454+
const LhsScalar* lhs, Index lhsStride,
455+
const RhsScalar* rhs, Index rhsIncr,
456+
ResScalar* res, Index resIncr,
457+
const RhsScalar& alpha) {
458+
for (Index i = 0; i < rows; i++) {
459+
res[i * resIncr] += stan::math::var(new stan::math::gevv_vvv_vari(
460+
&alpha,
461+
(static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
462+
? (&lhs[i])
463+
: (&lhs[i * lhsStride]),
464+
(static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
465+
? (lhsStride)
466+
: (1),
467+
rhs, rhsIncr, cols));
468+
}
469+
}
470+
};
471+
397472
#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
398473
template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
399474
int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride>

0 commit comments

Comments
 (0)