Skip to content

Commit 44529fd

Browse files
committed
Adjoint ode_bdf passing ode_bdf tests. integrate_ode_bdf not working yet. (design-doc #19)
1 parent 9831634 commit 44529fd

File tree

7 files changed

+463
-153
lines changed

7 files changed

+463
-153
lines changed

stan/math/prim/functor/ode_rk45.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,6 @@ ode_rk45_tol(const F& f,
163163
bool observer_initial_recorded = false;
164164
size_t time_index = 0;
165165

166-
max_step_checker step_checker(max_num_steps);
167-
168166
// avoid recording of the initial state which is included by the
169167
// conventions of odeint in the output
170168
auto filtered_observer
@@ -175,7 +173,6 @@ ode_rk45_tol(const F& f,
175173
}
176174
y.emplace_back(ode_store_sensitivities(f, coupled_state, y0, t0,
177175
ts[time_index], msgs, args...));
178-
step_checker.reset();
179176
time_index++;
180177
};
181178

@@ -189,7 +186,8 @@ ode_rk45_tol(const F& f,
189186
runge_kutta_dopri5<Eigen::VectorXd, double, Eigen::VectorXd, double,
190187
vector_space_algebra>()),
191188
std::ref(coupled_system), initial_coupled_state, std::begin(ts_vec),
192-
std::end(ts_vec), step_size, filtered_observer, step_checker);
189+
std::end(ts_vec), step_size, filtered_observer,
190+
max_step_checker(max_num_steps));
193191

194192
return y;
195193
}

stan/math/rev/core/grad.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/rev/core/nested_size.hpp>
88
#include <stan/math/rev/core/vari.hpp>
99
#include <vector>
10+
#include <iostream>
1011

1112
namespace stan {
1213
namespace math {
@@ -27,8 +28,9 @@ namespace math {
2728
* @param vi Variable implementation for root of partial
2829
* derivative propagation.
2930
*/
30-
static void grad(vari* vi) {
31-
vi->init_dependent();
31+
static void grad(vari* vi = NULL) {
32+
if(vi != NULL)
33+
vi->init_dependent();
3234
std::vector<vari*>& var_stack = ChainableStack::instance_->var_stack_;
3335
size_t end = var_stack.size();
3436
size_t beginning = empty_nested() ? 0 : end - nested_size();

0 commit comments

Comments
 (0)