Skip to content

Commit fa9454b

Browse files
committed
Adapt Stan Math for Eigen 5.0.1 compatibility
1 parent 78ba45a commit fa9454b

File tree

82 files changed

+899
-498
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+899
-498
lines changed

doxygen/doxygen.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2292,7 +2292,7 @@ SEARCH_INCLUDES = YES
22922292

22932293
INCLUDE_PATH = ./ \
22942294
lib/tbb_2020.3/include \
2295-
lib/eigen_3.4.0 \
2295+
lib/eigen_5.0.1 \
22962296
lib/sundials_6.1.1/include \
22972297
lib/sundials_6.1.1/src/sundials \
22982298
lib/opencl_3.0.0 \

make/libraries

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ endef
99

1010
MATH ?=
1111
BOOST ?= $(MATH)lib/boost_1.87.0
12-
EIGEN ?= $(MATH)lib/eigen_3.4.0
12+
EIGEN ?= $(MATH)lib/eigen_5.0.1
1313
OPENCL ?= $(MATH)lib/opencl_3.0.0
1414
TBB ?= $(MATH)lib/tbb_2020.3
1515
SUNDIALS ?= $(MATH)lib/sundials_6.1.1

stan/math/fwd/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <stan/math/fwd/fun/fabs.hpp>
3434
#include <stan/math/fwd/fun/falling_factorial.hpp>
3535
#include <stan/math/fwd/fun/fdim.hpp>
36+
#include <stan/math/fwd/fun/fft.hpp>
3637
#include <stan/math/fwd/fun/floor.hpp>
3738
#include <stan/math/fwd/fun/fma.hpp>
3839
#include <stan/math/fwd/fun/fmax.hpp>

stan/math/fwd/fun/fft.hpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
#ifndef STAN_MATH_FWD_FUN_FFT_HPP
2+
#define STAN_MATH_FWD_FUN_FFT_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/fun/fft.hpp>
6+
#include <stan/math/prim/fun/to_ref.hpp>
7+
#include <stan/math/prim/meta.hpp>
8+
#include <stan/math/fwd/core.hpp>
9+
#include <stan/math/fwd/fun/value_of.hpp>
10+
#include <stan/math/fwd/meta.hpp>
11+
#include <complex>
12+
13+
namespace stan {
14+
namespace math {
15+
16+
/**
17+
* Return the discrete Fourier transform of the specified complex
18+
* vector for forward-mode autodiff.
19+
*
20+
* @tparam V type of complex vector argument
21+
* @param[in] x vector to transform
22+
* @return discrete Fourier transform of `x`
23+
*/
24+
template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
25+
require_fvar_t<base_type_t<value_type_t<V>>>* = nullptr>
26+
inline Eigen::Matrix<scalar_type_t<V>, -1, 1> fft(V&& x) {
27+
using scalar_t = scalar_type_t<V>;
28+
using fvar_t = base_type_t<scalar_t>;
29+
using complex_t = std::complex<partials_type_t<fvar_t>>;
30+
decltype(auto) x_ref = to_ref(std::forward<V>(x));
31+
if (x_ref.size() <= 1) {
32+
return Eigen::Matrix<scalar_type_t<V>, -1, 1>(x_ref);
33+
}
34+
35+
Eigen::Matrix<complex_t, -1, 1> x_val = value_of(x_ref);
36+
Eigen::Matrix<complex_t, -1, 1> x_d = x_ref.unaryExpr(
37+
[](const auto& z) { return complex_t(z.real().d(), z.imag().d()); });
38+
39+
auto y_val = fft(std::move(x_val));
40+
auto y_d = fft(std::move(x_d));
41+
42+
using out_t = Eigen::Matrix<scalar_type_t<V>, -1, 1>;
43+
out_t y
44+
= y_val.binaryExpr(y_d, [](const complex_t& val, const complex_t& der) {
45+
return std::complex<fvar_t>{fvar_t(val.real(), der.real()),
46+
fvar_t(val.imag(), der.imag())};
47+
});
48+
return y;
49+
}
50+
51+
/**
52+
* Return the inverse discrete Fourier transform of the specified
53+
* complex vector for forward-mode autodiff.
54+
*
55+
* @tparam V type of complex vector argument
56+
* @param[in] y vector to inverse transform
57+
* @return inverse discrete Fourier transform of `y`
58+
*/
59+
template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
60+
require_fvar_t<base_type_t<value_type_t<V>>>* = nullptr>
61+
inline Eigen::Matrix<scalar_type_t<V>, -1, 1> inv_fft(V&& y) {
62+
using scalar_t = scalar_type_t<V>;
63+
using fvar_t = base_type_t<scalar_t>;
64+
using complex_t = std::complex<partials_type_t<fvar_t>>;
65+
decltype(auto) y_ref = to_ref(std::forward<V>(y));
66+
if (y_ref.size() <= 1) {
67+
return Eigen::Matrix<scalar_type_t<V>, -1, 1>(y_ref);
68+
}
69+
70+
Eigen::Matrix<complex_t, -1, 1> y_val = value_of(y_ref);
71+
Eigen::Matrix<complex_t, -1, 1> y_d = y_ref.unaryExpr(
72+
[](const auto& z) { return complex_t(z.real().d(), z.imag().d()); });
73+
74+
auto x_val = inv_fft(std::move(y_val));
75+
auto x_d = inv_fft(std::move(y_d));
76+
77+
using out_t = Eigen::Matrix<scalar_type_t<V>, -1, 1>;
78+
out_t x
79+
= x_val.binaryExpr(x_d, [](const complex_t& val, const complex_t& der) {
80+
return std::complex<fvar_t>{fvar_t(val.real(), der.real()),
81+
fvar_t(val.imag(), der.imag())};
82+
});
83+
return x;
84+
}
85+
86+
/**
87+
* Return the two-dimensional discrete Fourier transform of the
88+
* specified complex matrix for forward-mode autodiff.
89+
*
90+
* @tparam M type of complex matrix argument
91+
* @param[in] x matrix to transform
92+
* @return discrete 2D Fourier transform of `x`
93+
*/
94+
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
95+
require_fvar_t<base_type_t<value_type_t<M>>>* = nullptr>
96+
inline Eigen::Matrix<scalar_type_t<M>, -1, -1> fft2(M&& x) {
97+
using scalar_t = scalar_type_t<M>;
98+
using fvar_t = base_type_t<scalar_t>;
99+
using complex_t = std::complex<partials_type_t<fvar_t>>;
100+
decltype(auto) x_ref = to_ref(std::forward<M>(x));
101+
Eigen::Matrix<complex_t, -1, -1> x_val = value_of(x_ref);
102+
Eigen::Matrix<complex_t, -1, -1> x_d = x_ref.unaryExpr(
103+
[](const auto& z) { return complex_t(z.real().d(), z.imag().d()); });
104+
105+
auto y_val = fft2(std::move(x_val));
106+
auto y_d = fft2(std::move(x_d));
107+
108+
using out_t = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
109+
out_t y
110+
= y_val.binaryExpr(y_d, [](const complex_t& val, const complex_t& der) {
111+
return std::complex<fvar_t>{fvar_t(val.real(), der.real()),
112+
fvar_t(val.imag(), der.imag())};
113+
});
114+
return y;
115+
}
116+
117+
/**
118+
* Return the two-dimensional inverse discrete Fourier transform of
119+
* the specified complex matrix for forward-mode autodiff.
120+
*
121+
* @tparam M type of complex matrix argument
122+
* @param[in] y matrix to inverse transform
123+
* @return inverse discrete 2D Fourier transform of `y`
124+
*/
125+
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
126+
require_fvar_t<base_type_t<value_type_t<M>>>* = nullptr>
127+
inline Eigen::Matrix<scalar_type_t<M>, -1, -1> inv_fft2(M&& y) {
128+
using scalar_t = scalar_type_t<M>;
129+
using fvar_t = base_type_t<scalar_t>;
130+
using complex_t = std::complex<partials_type_t<fvar_t>>;
131+
decltype(auto) y_ref = to_ref(std::forward<M>(y));
132+
Eigen::Matrix<complex_t, -1, -1> y_val = value_of(y_ref);
133+
Eigen::Matrix<complex_t, -1, -1> y_d = y_ref.unaryExpr(
134+
[](const auto& z) { return complex_t(z.real().d(), z.imag().d()); });
135+
136+
auto x_val = inv_fft2(std::move(y_val));
137+
auto x_d = inv_fft2(std::move(y_d));
138+
139+
using out_t = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
140+
out_t x
141+
= x_val.binaryExpr(x_d, [](const complex_t& val, const complex_t& der) {
142+
return std::complex<fvar_t>{fvar_t(val.real(), der.real()),
143+
fvar_t(val.imag(), der.imag())};
144+
});
145+
return x;
146+
}
147+
148+
} // namespace math
149+
} // namespace stan
150+
151+
#endif

stan/math/fwd/fun/mdivide_left_tri_low.hpp

Lines changed: 41 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -9,109 +9,77 @@
99
#include <stan/math/fwd/fun/multiply.hpp>
1010
#include <stan/math/prim/err.hpp>
1111
#include <stan/math/prim/fun/mdivide_left_tri_low.hpp>
12-
12+
#include <stan/math/prim/fun/eval.hpp>
13+
#include <stan/math/prim/fun/subtract.hpp>
1314
namespace stan {
1415
namespace math {
1516

1617
template <typename T1, typename T2,
1718
require_all_eigen_vt<is_fvar, T1, T2>* = nullptr,
1819
require_vt_same<T1, T2>* = nullptr>
19-
inline Eigen::Matrix<value_type_t<T1>, T1::RowsAtCompileTime,
20-
T2::ColsAtCompileTime>
21-
mdivide_left_tri_low(const T1& A, const T2& b) {
22-
using T = typename value_type_t<T1>::Scalar;
23-
constexpr int S1 = T1::RowsAtCompileTime;
24-
constexpr int C2 = T2::ColsAtCompileTime;
20+
inline Eigen::Matrix<value_type_t<T1>, std::decay_t<T1>::RowsAtCompileTime,
21+
std::decay_t<T2>::ColsAtCompileTime>
22+
mdivide_left_tri_low(T1&& A, T2&& b) {
23+
constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime;
24+
constexpr int C2 = std::decay_t<T2>::ColsAtCompileTime;
2525

2626
check_square("mdivide_left_tri_low", "A", A);
2727
check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
2828
if (A.size() == 0) {
2929
return {0, b.cols()};
3030
}
31-
32-
Eigen::Matrix<T, S1, S1> val_A(A.rows(), A.cols());
33-
Eigen::Matrix<T, S1, S1> deriv_A(A.rows(), A.cols());
34-
val_A.setZero();
35-
deriv_A.setZero();
36-
37-
const Eigen::Ref<const plain_type_t<T2>>& b_ref = b;
38-
const Eigen::Ref<const plain_type_t<T1>>& A_ref = A;
39-
for (size_type j = 0; j < A.cols(); j++) {
40-
for (size_type i = j; i < A.rows(); i++) {
41-
val_A(i, j) = A_ref(i, j).val_;
42-
deriv_A(i, j) = A_ref(i, j).d_;
43-
}
44-
}
45-
46-
Eigen::Matrix<T, S1, C2> inv_A_mult_b = mdivide_left(val_A, b_ref.val());
47-
48-
return to_fvar(inv_A_mult_b,
49-
mdivide_left(val_A, b_ref.d())
50-
- multiply(mdivide_left(val_A, deriv_A), inv_A_mult_b));
31+
decltype(auto) b_ref = to_ref(std::forward<T2>(b));
32+
decltype(auto) A_ref = to_ref(std::forward<T1>(A));
33+
auto inv_A_mult_b
34+
= eval(mdivide_left_tri<Eigen::Lower>(A_ref.val(), b_ref.val()));
35+
return to_fvar(
36+
inv_A_mult_b,
37+
subtract(mdivide_left_tri<Eigen::Lower>(A_ref.val(), b_ref.d()),
38+
multiply(mdivide_left_tri<Eigen::Lower>(
39+
A_ref.val(),
40+
A_ref.d().template triangularView<Eigen::Lower>()),
41+
inv_A_mult_b)));
5142
}
5243

5344
template <typename T1, typename T2, require_eigen_t<T1>* = nullptr,
5445
require_vt_same<double, T1>* = nullptr,
5546
require_eigen_vt<is_fvar, T2>* = nullptr>
56-
inline Eigen::Matrix<value_type_t<T2>, T1::RowsAtCompileTime,
57-
T2::ColsAtCompileTime>
58-
mdivide_left_tri_low(const T1& A, const T2& b) {
59-
constexpr int S1 = T1::RowsAtCompileTime;
60-
47+
inline Eigen::Matrix<value_type_t<T2>, std::decay_t<T1>::RowsAtCompileTime,
48+
std::decay_t<T2>::ColsAtCompileTime>
49+
mdivide_left_tri_low(T1&& A, T2&& b) {
50+
constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime;
6151
check_square("mdivide_left_tri_low", "A", A);
6252
check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
6353
if (A.size() == 0) {
6454
return {0, b.cols()};
6555
}
66-
67-
Eigen::Matrix<double, S1, S1> val_A(A.rows(), A.cols());
68-
val_A.setZero();
69-
70-
const Eigen::Ref<const plain_type_t<T2>>& b_ref = b;
71-
const Eigen::Ref<const plain_type_t<T1>>& A_ref = A;
72-
for (size_type j = 0; j < A.cols(); j++) {
73-
for (size_type i = j; i < A.rows(); i++) {
74-
val_A(i, j) = A_ref(i, j);
75-
}
76-
}
77-
78-
return to_fvar(mdivide_left(val_A, b_ref.val()),
79-
mdivide_left(val_A, b_ref.d()));
56+
decltype(auto) A_ref = to_ref(std::forward<T1>(A));
57+
decltype(auto) b_ref = to_ref(std::forward<T2>(b));
58+
return to_fvar(mdivide_left_tri<Eigen::Lower>(A_ref, b_ref.val()),
59+
mdivide_left_tri<Eigen::Lower>(A_ref, b_ref.d()));
8060
}
8161

8262
template <typename T1, typename T2, require_eigen_vt<is_fvar, T1>* = nullptr,
83-
require_eigen_t<T2>* = nullptr,
84-
require_vt_same<double, T2>* = nullptr>
85-
inline Eigen::Matrix<value_type_t<T1>, T1::RowsAtCompileTime,
86-
T2::ColsAtCompileTime>
87-
mdivide_left_tri_low(const T1& A, const T2& b) {
88-
using T = typename value_type_t<T1>::Scalar;
89-
constexpr int S1 = T1::RowsAtCompileTime;
90-
constexpr int C2 = T2::ColsAtCompileTime;
91-
63+
require_eigen_vt<std::is_floating_point, T2>* = nullptr>
64+
inline Eigen::Matrix<value_type_t<T1>, std::decay_t<T1>::RowsAtCompileTime,
65+
std::decay_t<T2>::ColsAtCompileTime>
66+
mdivide_left_tri_low(T1&& A, T2&& b) {
67+
constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime;
68+
constexpr int C2 = std::decay_t<T2>::ColsAtCompileTime;
9269
check_square("mdivide_left_tri_low", "A", A);
9370
check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
9471
if (A.size() == 0) {
9572
return {0, b.cols()};
9673
}
97-
98-
Eigen::Matrix<T, S1, S1> val_A(A.rows(), A.cols());
99-
Eigen::Matrix<T, S1, S1> deriv_A(A.rows(), A.cols());
100-
val_A.setZero();
101-
deriv_A.setZero();
102-
103-
const Eigen::Ref<const plain_type_t<T1>>& A_ref = A;
104-
for (size_type j = 0; j < A.cols(); j++) {
105-
for (size_type i = j; i < A.rows(); i++) {
106-
val_A(i, j) = A_ref(i, j).val_;
107-
deriv_A(i, j) = A_ref(i, j).d_;
108-
}
109-
}
110-
111-
Eigen::Matrix<T, S1, C2> inv_A_mult_b = mdivide_left(val_A, b);
112-
113-
return to_fvar(inv_A_mult_b,
114-
-multiply(mdivide_left(val_A, deriv_A), inv_A_mult_b));
74+
decltype(auto) A_ref = to_ref(std::forward<T1>(A));
75+
auto inv_A_mult_b
76+
= eval(mdivide_left_tri<Eigen::Lower>(A_ref.val(), std::forward<T2>(b)));
77+
return to_fvar(
78+
inv_A_mult_b,
79+
-multiply(
80+
mdivide_left_tri<Eigen::Lower>(
81+
A_ref.val(), A_ref.d().template triangularView<Eigen::Lower>()),
82+
inv_A_mult_b));
11583
}
11684

11785
} // namespace math

stan/math/fwd/fun/mdivide_right.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ mdivide_right(const EigMat1& A, const EigMat2& b) {
118118

119119
Eigen::Matrix<T, R1, C2> A_mult_inv_b = mdivide_right(A, val_b);
120120

121-
return to_fvar(A_mult_inv_b, -A_mult_inv_b * mdivide_right(deriv_b, val_b));
121+
return to_fvar(A_mult_inv_b,
122+
multiply(-A_mult_inv_b, mdivide_right(deriv_b, val_b)));
122123
}
123124

124125
} // namespace math

0 commit comments

Comments
 (0)