Skip to content

Commit e881e77

Browse files
committed
update expr tests to support running only prim and reverse mode
1 parent 35db9d9 commit e881e77

2 files changed

Lines changed: 31 additions & 27 deletions

File tree

test/unit/math/expr_tests.hpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ constexpr const char*
7171
* the associated array will be zero.
7272
*/
7373
template <typename ScalarType, std::size_t N>
74-
void expect_all_used_only_once(std::array<int, N>& arg_evals,
74+
inline void expect_all_used_only_once(std::array<int, N>& arg_evals,
7575
std::array<int, N>& size_of_arg) {
7676
for (int i = 0; i < N; ++i) {
7777
EXPECT_LE(arg_evals[i], size_of_arg[i])
@@ -105,7 +105,7 @@ void expect_all_used_only_once(std::array<int, N>& arg_evals,
105105
* @param arg A type derived from `Eigen::EigenBase`
106106
*/
107107
template <typename EigMat, stan::require_eigen_t<EigMat>* = nullptr>
108-
auto make_expr(int& count, EigMat&& arg) {
108+
inline auto make_expr(int& count, EigMat&& arg) {
109109
return arg.unaryExpr(
110110
stan::test::counterOp<stan::scalar_type_t<EigMat>>(&count));
111111
}
@@ -115,7 +115,7 @@ auto make_expr(int& count, EigMat&& arg) {
115115
* @tparam Any type not derived from `Eigen::EigenBase`
116116
*/
117117
template <typename T, stan::require_not_eigen_t<T>* = nullptr>
118-
auto make_expr(int& /* count */, T&& arg) {
118+
inline constexpr auto make_expr(int& /* count */, T&& arg) {
119119
return arg;
120120
}
121121

@@ -129,7 +129,7 @@ auto make_expr(int& /* count */, T&& arg) {
129129
* `counterOp`.
130130
*/
131131
template <std::size_t N, typename... Args>
132-
auto make_expr_args(std::array<int, N>& expr_evals, Args&&... args) {
132+
inline constexpr auto make_expr_args(std::array<int, N>& expr_evals, Args&&... args) {
133133
return stan::math::index_apply<N>([&expr_evals, &args...](auto... Is) {
134134
return std::make_tuple(make_expr(expr_evals[Is], args)...);
135135
});
@@ -148,7 +148,7 @@ inline constexpr int eigen_size(T&& x) {
148148
* @tparam EigMat A type derived from `Eigen::EigenBase`
149149
*/
150150
template <typename EigMat, stan::require_eigen_t<EigMat>* = nullptr>
151-
inline int eigen_size(EigMat&& x) {
151+
inline constexpr Eigen::Index eigen_size(EigMat&& x) {
152152
return x.size();
153153
}
154154

@@ -161,7 +161,7 @@ inline int eigen_size(EigMat&& x) {
161161
* an Eigen type then the value is be zero.
162162
*/
163163
template <typename... Args>
164-
std::array<int, sizeof...(Args)> eigen_arg_sizes(Args&&... args) {
164+
inline constexpr std::array<int, sizeof...(Args)> eigen_arg_sizes(Args&&... args) {
165165
return std::array<int, sizeof...(Args)>{eigen_size(args)...};
166166
}
167167

@@ -170,7 +170,7 @@ std::array<int, sizeof...(Args)> eigen_arg_sizes(Args&&... args) {
170170
*/
171171
template <typename ScalarType, typename F, typename... Args,
172172
require_all_not_eigen_t<Args...>* = nullptr>
173-
void check_expr_test(F&& f, Args&&... args) {}
173+
inline constexpr void check_expr_test(F&& f, Args&&... args) {}
174174

175175
/**
176176
* Check whether any Eigen inputs are executed too many times.
@@ -200,7 +200,7 @@ void check_expr_test(F&& f, Args&&... args) {}
200200
*/
201201
template <typename ScalarType, typename F, typename... Args,
202202
require_any_eigen_t<Args...>* = nullptr>
203-
void check_expr_test(F&& f, Args&&... args) {
203+
inline void check_expr_test(F&& f, Args&&... args) {
204204
std::array<int, sizeof...(args)> expr_eval_counts;
205205
for (int i = 0; i < sizeof...(args); ++i) {
206206
expr_eval_counts[i] = 0;
@@ -219,7 +219,7 @@ void check_expr_test(F&& f, Args&&... args) {
219219
[&f](auto&&... args) { return f(std::forward<decltype(args)>(args)...); },
220220
expr_args));
221221
expect_all_used_only_once<ScalarType>(expr_eval_counts, size_of_eigen_args);
222-
if (stan::is_var<ScalarType>::value) {
222+
if constexpr (stan::is_var<ScalarType>::value) {
223223
stan::math::recover_memory();
224224
}
225225
}
@@ -233,10 +233,10 @@ void check_expr_test(F&& f, Args&&... args) {
233233
* @param f functor whose `operator()` will be called.
234234
* @param args pack of arguments to pass to the functor.
235235
*/
236-
template <typename F, typename... Args,
236+
template <bool ReverseOnly = false, typename F, typename... Args,
237237
require_all_st_stan_scalar<Args...>* = nullptr,
238238
require_all_not_st_complex<Args...>* = nullptr>
239-
void check_expr_test(F&& f, Args&&... args) {
239+
inline void check_expr_test(F&& f, Args&&... args) {
240240
try {
241241
stan::test::internal::check_expr_test<double>(f, args...);
242242
try {
@@ -245,12 +245,14 @@ void check_expr_test(F&& f, Args&&... args) {
245245
} catch (const std::exception& e) {
246246
stan::math::recover_memory();
247247
}
248-
stan::test::internal::check_expr_test<stan::math::fvar<double>>(f, args...);
248+
if constexpr (!ReverseOnly) {
249+
stan::test::internal::check_expr_test<stan::math::fvar<double>>(f, args...);
250+
}
249251
} catch (const std::exception& e) {
250252
}
251253
}
252254

253-
template <typename F, typename... Args,
255+
template <bool ReverseOnly = false, typename F, typename... Args,
254256
require_any_st_complex<Args...>* = nullptr>
255257
void check_expr_test(F&& f, Args&&... args) {
256258
try {
@@ -262,8 +264,10 @@ void check_expr_test(F&& f, Args&&... args) {
262264
} catch (const std::exception& e) {
263265
stan::math::recover_memory();
264266
}
265-
stan::test::internal::check_expr_test<
266-
std::complex<stan::math::fvar<double>>>(f, args...);
267+
if constexpr (!ReverseOnly) {
268+
stan::test::internal::check_expr_test<
269+
std::complex<stan::math::fvar<double>>>(f, args...);
270+
}
267271
} catch (const std::exception& e) {
268272
}
269273
}

test/unit/math/test_ad.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ inline void expect_ad_helper(const ad_tolerances& tols, const F& f, const G& g,
490490
size_t result_size = 0;
491491
try {
492492
auto y1 = eval(f(xs...)); // original types, including int
493-
stan::test::check_expr_test(f, xs...);
493+
stan::test::check_expr_test<ReverseOnly>(f, xs...);
494494
auto y2 = eval(g(x)); // all int cast to double
495495
auto y1_serial = serialize<double>(y1);
496496
expect_near_rel("expect_ad_helper", y1_serial, y2, 1e-10);
@@ -551,7 +551,7 @@ inline void expect_ad_v(const ad_tolerances& tols, const F& f, int x) {
551551
// if f throws on int, must throw everywhere with double
552552
try {
553553
f(x);
554-
stan::test::check_expr_test(f, x);
554+
stan::test::check_expr_test<ReverseOnly>(f, x);
555555
} catch (...) {
556556
expect_all_throw<ReverseOnly>(f, x_dbl);
557557
return;
@@ -619,7 +619,7 @@ inline void expect_ad_vv(const ad_tolerances& tols, const F& f, int x1,
619619
const T2& x2) {
620620
try {
621621
f(x1, x2);
622-
stan::test::check_expr_test(f, x1, x2);
622+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2);
623623
} catch (...) {
624624
expect_all_throw<ReverseOnly>(f, x1, x2);
625625
return;
@@ -643,7 +643,7 @@ inline void expect_ad_vv(const ad_tolerances& tols, const F& f, const T1& x1,
643643
int x2) {
644644
try {
645645
f(x1, x2);
646-
stan::test::check_expr_test(f, x1, x2);
646+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2);
647647
} catch (...) {
648648
expect_all_throw<ReverseOnly>(f, x1, x2);
649649
return;
@@ -668,7 +668,7 @@ inline void expect_ad_vv(const ad_tolerances& tols, const F& f, int x1,
668668
// this one needs throw test because it's not handled by recursion
669669
try {
670670
f(x1, x2);
671-
stan::test::check_expr_test(f, x1, x2);
671+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2);
672672
} catch (...) {
673673
expect_all_throw<ReverseOnly>(f, x1, x2);
674674
return;
@@ -784,7 +784,7 @@ inline void expect_ad_vvv(const ad_tolerances& tols, const F& f, int x1, int x2,
784784
const T3& x3) {
785785
try {
786786
f(x1, x2, x3);
787-
stan::test::check_expr_test(f, x1, x2, x3);
787+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2, x3);
788788
} catch (...) {
789789
expect_all_throw<ReverseOnly>(f, x1, x2, x3);
790790
return;
@@ -814,7 +814,7 @@ inline void expect_ad_vvv(const ad_tolerances& tols, const F& f, int x1,
814814
const T2& x2, const T3& x3) {
815815
try {
816816
f(x1, x2, x3);
817-
stan::test::check_expr_test(f, x1, x2, x3);
817+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2, x3);
818818
} catch (...) {
819819
expect_all_throw<ReverseOnly>(f, x1, x2, x3);
820820
return;
@@ -839,7 +839,7 @@ inline void expect_ad_vvv(const ad_tolerances& tols, const F& f, const T1& x1,
839839
int x2, const T3& x3) {
840840
try {
841841
f(x1, x2, x3);
842-
stan::test::check_expr_test(f, x1, x2, x3);
842+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2, x3);
843843
} catch (...) {
844844
expect_all_throw<ReverseOnly>(f, x1, x2, x3);
845845
return;
@@ -864,7 +864,7 @@ inline void expect_ad_vvv(const ad_tolerances& tols, const F& f, const T1& x1,
864864
const T2& x2, int x3) {
865865
try {
866866
f(x1, x2, x3);
867-
stan::test::check_expr_test(f, x1, x2, x3);
867+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2, x3);
868868
} catch (...) {
869869
expect_all_throw<ReverseOnly>(f, x1, x2, x3);
870870
return;
@@ -889,7 +889,7 @@ inline void expect_ad_vvv(const ad_tolerances& tols, const F& f, int x1,
889889
const T2& x2, int x3) {
890890
try {
891891
f(x1, x2, x3);
892-
stan::test::check_expr_test(f, x1, x2, x3);
892+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2, x3);
893893
} catch (...) {
894894
expect_all_throw<ReverseOnly>(f, x1, x2, x3);
895895
return;
@@ -919,7 +919,7 @@ inline void expect_ad_vvv(const ad_tolerances& tols, const F& f, const T1& x1,
919919
int x2, int x3) {
920920
try {
921921
f(x1, x2, x3);
922-
stan::test::check_expr_test(f, x1, x2, x3);
922+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2, x3);
923923
} catch (...) {
924924
expect_all_throw<ReverseOnly>(f, x1, x2, x3);
925925
return;
@@ -950,7 +950,7 @@ inline void expect_ad_vvv(const ad_tolerances& tols, const F& f, int x1, int x2,
950950
// test exception behavior; other exception cases tested recursively
951951
try {
952952
f(x1, x2, x3);
953-
stan::test::check_expr_test(f, x1, x2, x3);
953+
stan::test::check_expr_test<ReverseOnly>(f, x1, x2, x3);
954954
} catch (...) {
955955
expect_all_throw<ReverseOnly>(f, x1, x2, x3);
956956
return;

0 commit comments

Comments
 (0)