88namespace stan {
99namespace math {
1010
11- inline void zero_adjoints ();
12-
13- template <typename T, typename ... Pargs, require_st_arithmetic<T>* = nullptr >
14- inline void zero_adjoints (T& x, Pargs&... args);
15-
16- template <typename ... Pargs>
17- inline void zero_adjoints (var& x, Pargs&... args);
18-
19- template <int R, int C, typename ... Pargs>
20- inline void zero_adjoints (Eigen::Matrix<var, R, C>& x, Pargs&... args);
21-
22- template <typename T, typename ... Pargs, require_st_autodiff<T>* = nullptr >
23- inline void zero_adjoints (std::vector<T>& x, Pargs&... args);
24-
2511/* *
2612 * End of recursion for set_zero_adjoints
2713 */
28- inline void zero_adjoints () {}
14+ inline void zero_adjoints () noexcept {}
2915
3016/* *
3117 * Do nothing for non-autodiff arguments. Recursively call zero_adjoints
@@ -37,10 +23,8 @@ inline void zero_adjoints() {}
3723 * @param x current argument
3824 * @param args rest of arguments to zero
3925 */
40- template <typename T, typename ... Pargs, require_st_arithmetic<T>*>
41- inline void zero_adjoints (T& x, Pargs&... args) {
42- zero_adjoints (args...);
43- }
26+ template <typename T, require_st_arithmetic<T>* = nullptr >
27+ inline void zero_adjoints (T& x) noexcept {}
4428
4529/* *
4630 * Zero the adjoint of the vari in the first argument. Recursively call
@@ -52,11 +36,7 @@ inline void zero_adjoints(T& x, Pargs&... args) {
5236 * @param x current argument
5337 * @param args rest of arguments to zero
5438 */
55- template <typename ... Pargs>
56- inline void zero_adjoints (var& x, Pargs&... args) {
57- x.vi_ ->set_zero_adjoint ();
58- zero_adjoints (args...);
59- }
39+ inline void zero_adjoints (var& x) { x.adj () = 0 ; }
6040
6141/* *
6242 * Zero the adjoints of the varis of every var in an Eigen::Matrix
@@ -68,11 +48,10 @@ inline void zero_adjoints(var& x, Pargs&... args) {
6848 * @param x current argument
6949 * @param args rest of arguments to zero
7050 */
71- template <int R, int C, typename ... Pargs >
72- inline void zero_adjoints (Eigen::Matrix<var, R, C> & x, Pargs&... args ) {
51+ template <typename EigMat, require_eigen_vt<is_autodiff, EigMat>* = nullptr >
52+ inline void zero_adjoints (EigMat & x) {
7353 for (size_t i = 0 ; i < x.size (); ++i)
74- x.coeffRef (i).vi_ ->set_zero_adjoint ();
75- zero_adjoints (args...);
54+ x.coeffRef (i).adj () = 0 ;
7655}
7756
7857/* *
@@ -85,11 +64,12 @@ inline void zero_adjoints(Eigen::Matrix<var, R, C>& x, Pargs&... args) {
8564 * @param x current argument
8665 * @param args rest of arguments to zero
8766 */
88- template <typename T, typename ... Pargs, require_st_autodiff<T>*>
89- inline void zero_adjoints (std::vector<T>& x, Pargs&... args) {
90- for (size_t i = 0 ; i < x.size (); ++i)
67+ template <typename StdVec,
68+ require_std_vector_st<is_autodiff, StdVec>* = nullptr >
69+ inline void zero_adjoints (StdVec& x) {
70+ for (size_t i = 0 ; i < x.size (); ++i) {
9171 zero_adjoints (x[i]);
92- zero_adjoints (args...);
72+ }
9373}
9474
9575} // namespace math
0 commit comments