Skip to content

Commit 43fb63b

Browse files
committed
Switched integrate_ode_bdf, integrate_ode_adams, and integrate_ode_rk45 to use functors compatible with previous release
1 parent a37be47 commit 43fb63b

13 files changed

+52
-26
lines changed

stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct integrate_ode_std_vector_interface_adapter {
3030
std::ostream* msgs, const std::vector<T2>& theta,
3131
const std::vector<double>& x,
3232
const std::vector<int>& x_int) const {
33-
return to_vector(f_(t, to_array_1d(y), msgs, theta, x, x_int));
33+
return to_vector(f_(t, to_array_1d(y), theta, x, x_int, msgs));
3434
}
3535
};
3636

test/unit/math/prim/functor/forced_harmonic_oscillator.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ struct forced_harm_osc_ode_fun {
1313
// parameters
1414
// double data
1515
// integer data
16-
operator()(const T0& t_in, const std::vector<T1>& y_in, std::ostream* msgs,
16+
operator()(const T0& t_in, const std::vector<T1>& y_in,
1717
const std::vector<T2>& theta, const std::vector<double>& x,
18-
const std::vector<int>& x_int) const {
18+
const std::vector<int>& x_int, std::ostream* msgs) const {
1919
if (y_in.size() != 2)
2020
throw std::domain_error(
2121
"this function was called with inconsistent state");

test/unit/math/prim/functor/harmonic_oscillator.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ struct harm_osc_ode_fun {
1313
// parameters
1414
// double data
1515
// integer data
16-
operator()(const T0& t_in, const std::vector<T1>& y_in, std::ostream* msgs,
16+
operator()(const T0& t_in, const std::vector<T1>& y_in,
1717
const std::vector<T2>& theta, const std::vector<double>& x,
18-
const std::vector<int>& x_int) const {
18+
const std::vector<int>& x_int, std::ostream* msgs) const {
1919
if (y_in.size() != 2)
2020
throw std::domain_error(
2121
"this function was called with inconsistent state");
@@ -55,9 +55,9 @@ struct harm_osc_ode_data_fun {
5555
// parameters
5656
// double data
5757
// integer data
58-
operator()(const T0& t_in, const std::vector<T1>& y_in, std::ostream* msgs,
58+
operator()(const T0& t_in, const std::vector<T1>& y_in,
5959
const std::vector<T2>& theta, const std::vector<double>& x,
60-
const std::vector<int>& x_int) const {
60+
const std::vector<int>& x_int, std::ostream* msgs) const {
6161
if (y_in.size() != 2)
6262
throw std::domain_error(
6363
"this function was called with inconsistent state");
@@ -79,9 +79,9 @@ struct harm_osc_ode_wrong_size_1_fun {
7979
// parameters
8080
// double data
8181
// integer data
82-
operator()(const T0& t_in, const std::vector<T1>& y_in, std::ostream* msgs,
82+
operator()(const T0& t_in, const std::vector<T1>& y_in,
8383
const std::vector<T2>& theta, const std::vector<double>& x,
84-
const std::vector<int>& x_int) const {
84+
const std::vector<int>& x_int, std::ostream* msgs) const {
8585
if (y_in.size() != 2)
8686
throw std::domain_error(
8787
"this function was called with inconsistent state");
@@ -103,9 +103,9 @@ struct harm_osc_ode_wrong_size_2_fun {
103103
// parameters
104104
// double data
105105
// integer data
106-
operator()(const T0& t_in, const std::vector<T1>& y_in, std::ostream* msgs,
106+
operator()(const T0& t_in, const std::vector<T1>& y_in,
107107
const std::vector<T2>& theta, const std::vector<double>& x,
108-
const std::vector<int>& x_int) const {
108+
const std::vector<int>& x_int, std::ostream* msgs) const {
109109
if (y_in.size() != 2)
110110
throw std::domain_error(
111111
"this function was called with inconsistent state");

test/unit/math/prim/functor/integrate_ode_std_vector_interface_adapter_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ TEST(StanMath, check_values) {
1919
double t = 1.0;
2020

2121
Eigen::VectorXd out1
22-
= stan::math::to_vector(harm_osc(t, y, nullptr, theta, x, x_int));
22+
= stan::math::to_vector(harm_osc(t, y, theta, x, x_int, nullptr));
2323
Eigen::VectorXd out2
2424
= harm_osc_adapted(t, stan::math::to_vector(y), nullptr, theta, x, x_int);
2525

test/unit/math/prim/functor/lorenz.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ struct lorenz_ode_fun {
2929
// parameters
3030
// double data
3131
// integer data
32-
operator()(const T0& t_in, const std::vector<T1>& y_in, std::ostream* msgs,
32+
operator()(const T0& t_in, const std::vector<T1>& y_in,
3333
const std::vector<T2>& theta, const std::vector<double>& x,
34-
const std::vector<int>& x_int) const {
34+
const std::vector<int>& x_int, std::ostream* msgs) const {
3535
return lorenz_ode(t_in, y_in, theta, x, x_int);
3636
}
3737
};

test/unit/math/rev/functor/coupled_mm.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ struct coupled_mm_ode_fun {
1212
// parameters
1313
// double data
1414
// integer data
15-
operator()(const T0& t_in, const std::vector<T1>& y, std::ostream* msgs,
15+
operator()(const T0& t_in, const std::vector<T1>& y,
1616
const std::vector<T2>& parms, const std::vector<double>& sx,
17-
const std::vector<int>& sx_int) const {
17+
const std::vector<int>& sx_int, std::ostream* msgs) const {
1818
std::vector<stan::return_type_t<T1, T2>> ydot(2);
1919

2020
const T2 act = parms[0];
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include <stan/math/rev.hpp>
2+
#include <gtest/gtest.h>
3+
#include <test/unit/util.hpp>
4+
#include <iostream>
5+
#include <vector>
6+
7+
struct Inverse {
8+
template <typename T0, typename T_y>
9+
inline Eigen::Matrix<stan::return_type_t<T_y>, Eigen::Dynamic, 1>
10+
operator()(const T0& t, const Eigen::Matrix<T_y, Eigen::Dynamic, 1>& y, std::ostream* msgs) const {
11+
Eigen::Matrix<T_y, Eigen::Dynamic, 1> out(1);
12+
out(0) = 1.0 / (y(0) - 1.0);
13+
return out;
14+
}
15+
};
16+
17+
TEST(StanMath, cvodes_error_handler) {
18+
Eigen::VectorXd y0 = Eigen::VectorXd::Ones(1);
19+
int t0 = 0;
20+
std::vector<double> ts = {0.45, 1.1};
21+
22+
std::string msg = "CVODES: CVode Internal t = 0 and h = 0 are such that t + h = t on the next step";
23+
24+
EXPECT_THROW_MSG(stan::math::ode_bdf(Inverse(), y0, t0, ts, nullptr),
25+
std::domain_error, msg);
26+
}

test/unit/math/rev/functor/integrate_ode_adams_rev_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ TEST(StanAgradRevOde_integrate_ode_adams, time_steps_as_param_AD) {
215215
} else {
216216
std::vector<double> y0(res_d.begin(), res_d.begin() + ns);
217217
EXPECT_FLOAT_EQ(g[k],
218-
ode(ts[i].val(), y0, msgs, theta, x, x_int)[j]);
218+
ode(ts[i].val(), y0, theta, x, x_int, msgs)[j]);
219219
}
220220
}
221221
stan::math::set_zero_all_adjoints();

test/unit/math/rev/functor/integrate_ode_bdf_rev_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ TEST(StanAgradRevOde_integrate_ode_bdf, time_steps_as_param_AD) {
245245
} else {
246246
std::vector<double> y0(res_d.begin(), res_d.begin() + ns);
247247
EXPECT_FLOAT_EQ(ts[k].adj(),
248-
ode(ts[i].val(), y0, msgs, theta, x, x_int)[j]);
248+
ode(ts[i].val(), y0, theta, x, x_int, msgs)[j]);
249249
}
250250
}
251251
stan::math::set_zero_all_adjoints();

test/unit/math/rev/functor/integrate_ode_cvodes_grad_rev_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ class sho_functor {
3838
inline std::vector<stan::return_type_t<T1, T2>> operator()(
3939
const T0& t_in, // time
4040
const std::vector<T1>& y_in, // state
41-
std::ostream* msgs, // error stream
4241
const std::vector<T2>& theta, // parameters
4342
const std::vector<double>& x, // double data
44-
const std::vector<int>& x_int) const { // integer data
43+
const std::vector<int>& x_int, // integer data
44+
std::ostream* msgs) const { // error stream
4545
if (y_in.size() != 2)
4646
throw std::domain_error("Functor called with inconsistent state");
4747

0 commit comments

Comments
 (0)