Skip to content

Commit 5534e22

Browse files
authored
Merge pull request #1905 from stan-dev/feature/adjoint-odes
Feature/adjoint odes implements #2486
2 parents 1c31a77 + 01c9fa1 commit 5534e22

23 files changed

+1536
-78
lines changed

make/tests

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ $(EXPRESSION_TESTS) : $(LIBSUNDIALS)
6464
# CVODES tests
6565
##
6666

67-
CVODES_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*cvodes*_test.cpp) $(call findfiles,test,*_bdf_*_test.cpp) $(call findfiles,test,*_adams_*_test.cpp) $(call findfiles,test,*_ode_typed_*test.cpp))
67+
CVODES_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*cvodes*_test.cpp) $(call findfiles,test,*_bdf_*_test.cpp) $(call findfiles,test,*_adams_*_test.cpp) $(call findfiles,test,*_ode_typed_*test.cpp) $(call findfiles,test,*_ode_adjoint_typed_*test.cpp))
6868
$(CVODES_TESTS) : $(LIBSUNDIALS)
6969

7070

stan/math/rev/core/zero_adjoints.hpp

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,10 @@
88
namespace stan {
99
namespace math {
1010

11-
inline void zero_adjoints();
12-
13-
template <typename T, typename... Pargs, require_st_arithmetic<T>* = nullptr>
14-
inline void zero_adjoints(T& x, Pargs&... args);
15-
16-
template <typename... Pargs>
17-
inline void zero_adjoints(var& x, Pargs&... args);
18-
19-
template <int R, int C, typename... Pargs>
20-
inline void zero_adjoints(Eigen::Matrix<var, R, C>& x, Pargs&... args);
21-
22-
template <typename T, typename... Pargs, require_st_autodiff<T>* = nullptr>
23-
inline void zero_adjoints(std::vector<T>& x, Pargs&... args);
24-
2511
/**
2612
* End of recursion for set_zero_adjoints
2713
*/
28-
inline void zero_adjoints() {}
14+
inline void zero_adjoints() noexcept {}
2915

3016
/**
3117
* Do nothing for non-autodiff arguments. Recursively call zero_adjoints
@@ -37,10 +23,8 @@ inline void zero_adjoints() {}
3723
* @param x current argument
3824
* @param args rest of arguments to zero
3925
*/
40-
template <typename T, typename... Pargs, require_st_arithmetic<T>*>
41-
inline void zero_adjoints(T& x, Pargs&... args) {
42-
zero_adjoints(args...);
43-
}
26+
template <typename T, require_st_arithmetic<T>* = nullptr>
27+
inline void zero_adjoints(T& x) noexcept {}
4428

4529
/**
4630
* Zero the adjoint of the vari in the first argument. Recursively call
@@ -52,11 +36,7 @@ inline void zero_adjoints(T& x, Pargs&... args) {
5236
* @param x current argument
5337
* @param args rest of arguments to zero
5438
*/
55-
template <typename... Pargs>
56-
inline void zero_adjoints(var& x, Pargs&... args) {
57-
x.vi_->set_zero_adjoint();
58-
zero_adjoints(args...);
59-
}
39+
inline void zero_adjoints(var& x) { x.adj() = 0; }
6040

6141
/**
6242
* Zero the adjoints of the varis of every var in an Eigen::Matrix
@@ -68,11 +48,10 @@ inline void zero_adjoints(var& x, Pargs&... args) {
6848
* @param x current argument
6949
* @param args rest of arguments to zero
7050
*/
71-
template <int R, int C, typename... Pargs>
72-
inline void zero_adjoints(Eigen::Matrix<var, R, C>& x, Pargs&... args) {
51+
template <typename EigMat, require_eigen_vt<is_autodiff, EigMat>* = nullptr>
52+
inline void zero_adjoints(EigMat& x) {
7353
for (size_t i = 0; i < x.size(); ++i)
74-
x.coeffRef(i).vi_->set_zero_adjoint();
75-
zero_adjoints(args...);
54+
x.coeffRef(i).adj() = 0;
7655
}
7756

7857
/**
@@ -85,11 +64,12 @@ inline void zero_adjoints(Eigen::Matrix<var, R, C>& x, Pargs&... args) {
8564
* @param x current argument
8665
* @param args rest of arguments to zero
8766
*/
88-
template <typename T, typename... Pargs, require_st_autodiff<T>*>
89-
inline void zero_adjoints(std::vector<T>& x, Pargs&... args) {
90-
for (size_t i = 0; i < x.size(); ++i)
67+
template <typename StdVec,
68+
require_std_vector_st<is_autodiff, StdVec>* = nullptr>
69+
inline void zero_adjoints(StdVec& x) {
70+
for (size_t i = 0; i < x.size(); ++i) {
9171
zero_adjoints(x[i]);
92-
zero_adjoints(args...);
72+
}
9373
}
9474

9575
} // namespace math

stan/math/rev/functor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <stan/math/rev/functor/integrate_ode_bdf.hpp>
1818
#include <stan/math/rev/functor/ode_adams.hpp>
1919
#include <stan/math/rev/functor/ode_bdf.hpp>
20+
#include <stan/math/rev/functor/ode_adjoint.hpp>
2021
#include <stan/math/rev/functor/ode_store_sensitivities.hpp>
2122
#include <stan/math/rev/functor/jacobian.hpp>
2223
#include <stan/math/rev/functor/kinsol_data.hpp>

stan/math/rev/functor/coupled_ode_system.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/rev/meta.hpp>
77
#include <stan/math/rev/core.hpp>
88
#include <stan/math/rev/fun/value_of.hpp>
9+
#include <stan/math/prim/functor/for_each.hpp>
910
#include <stan/math/prim/err.hpp>
1011
#include <stdexcept>
1112
#include <ostream>
@@ -137,8 +138,10 @@ struct coupled_ode_system_impl<false, F, T_y0, Args...> {
137138

138139
y_adjoints_ = y_vars.adj();
139140

140-
// memset was faster than Eigen setZero
141-
memset(args_adjoints_.data(), 0, sizeof(double) * num_args_vars);
141+
if (args_adjoints_.size() > 0) {
142+
memset(args_adjoints_.data(), 0,
143+
sizeof(double) * args_adjoints_.size());
144+
}
142145

143146
apply(
144147
[&](auto&&... args) {
@@ -148,7 +151,8 @@ struct coupled_ode_system_impl<false, F, T_y0, Args...> {
148151

149152
// The vars here do not live on the nested stack so must be zero'd
150153
// separately
151-
apply([&](auto&&... args) { zero_adjoints(args...); }, local_args_tuple_);
154+
stan::math::for_each([](auto&& arg) { zero_adjoints(arg); },
155+
local_args_tuple_);
152156

153157
// No need to zero adjoints after last sweep
154158
if (i + 1 < N_) {

stan/math/rev/functor/cvodes_integrator.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,11 @@ class cvodes_integrator {
279279
CVodeSetUserData(cvodes_mem, reinterpret_cast<void*>(this)),
280280
"CVodeSetUserData");
281281

282-
cvodes_set_options(cvodes_mem, relative_tolerance_, absolute_tolerance_,
283-
max_num_steps_);
282+
cvodes_set_options(cvodes_mem, max_num_steps_);
283+
284+
check_flag_sundials(CVodeSStolerances(cvodes_mem, relative_tolerance_,
285+
absolute_tolerance_),
286+
"CVodeSStolerances");
284287

285288
check_flag_sundials(CVodeSetLinearSolver(cvodes_mem, LS_, A_),
286289
"CVodeSetLinearSolver");

0 commit comments

Comments
 (0)